Skip to content

Commit cda03db

Browse files
authored
Merge pull request #11 from asmeurer/alias-cleanups
Pass keyword arguments through in the aliases
2 parents 88e3d77 + 055f793 commit cda03db

File tree

4 files changed

+126
-126
lines changed

4 files changed

+126
-126
lines changed

array_api_compat/common/_aliases.py

Lines changed: 59 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -14,43 +14,6 @@
1414

1515
from ._helpers import _check_device, _is_numpy_array, get_namespace
1616

17-
# Basic renames
18-
def acos(x, /, xp):
19-
return xp.arccos(x)
20-
21-
def acosh(x, /, xp):
22-
return xp.arccosh(x)
23-
24-
def asin(x, /, xp):
25-
return xp.arcsin(x)
26-
27-
def asinh(x, /, xp):
28-
return xp.arcsinh(x)
29-
30-
def atan(x, /, xp):
31-
return xp.arctan(x)
32-
33-
def atan2(x1, x2, /, xp):
34-
return xp.arctan2(x1, x2)
35-
36-
def atanh(x, /, xp):
37-
return xp.arctanh(x)
38-
39-
def bitwise_left_shift(x1, x2, /, xp):
40-
return xp.left_shift(x1, x2)
41-
42-
def bitwise_invert(x, /, xp):
43-
return xp.invert(x)
44-
45-
def bitwise_right_shift(x1, x2, /, xp):
46-
return xp.right_shift(x1, x2)
47-
48-
def concat(arrays: Union[Tuple[ndarray, ...], List[ndarray]], /, xp, *, axis: Optional[int] = 0) -> ndarray:
49-
return xp.concatenate(arrays, axis=axis)
50-
51-
def pow(x1, x2, /, xp):
52-
return xp.power(x1, x2)
53-
5417
# These functions are modified from the NumPy versions.
5518

5619
def arange(
@@ -62,25 +25,28 @@ def arange(
6225
xp,
6326
dtype: Optional[Dtype] = None,
6427
device: Optional[Device] = None,
28+
**kwargs
6529
) -> ndarray:
6630
_check_device(xp, device)
67-
return xp.arange(start, stop=stop, step=step, dtype=dtype)
31+
return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs)
6832

6933
def empty(
7034
shape: Union[int, Tuple[int, ...]],
7135
xp,
7236
*,
7337
dtype: Optional[Dtype] = None,
7438
device: Optional[Device] = None,
39+
**kwargs
7540
) -> ndarray:
7641
_check_device(xp, device)
77-
return xp.empty(shape, dtype=dtype)
42+
return xp.empty(shape, dtype=dtype, **kwargs)
7843

7944
def empty_like(
80-
x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
45+
x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None,
46+
**kwargs
8147
) -> ndarray:
8248
_check_device(xp, device)
83-
return xp.empty_like(x, dtype=dtype)
49+
return xp.empty_like(x, dtype=dtype, **kwargs)
8450

8551
def eye(
8652
n_rows: int,
@@ -91,9 +57,10 @@ def eye(
9157
k: int = 0,
9258
dtype: Optional[Dtype] = None,
9359
device: Optional[Device] = None,
60+
**kwargs,
9461
) -> ndarray:
9562
_check_device(xp, device)
96-
return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype)
63+
return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs)
9764

9865
def full(
9966
shape: Union[int, Tuple[int, ...]],
@@ -102,9 +69,10 @@ def full(
10269
*,
10370
dtype: Optional[Dtype] = None,
10471
device: Optional[Device] = None,
72+
**kwargs,
10573
) -> ndarray:
10674
_check_device(xp, device)
107-
return xp.full(shape, fill_value, dtype=dtype)
75+
return xp.full(shape, fill_value, dtype=dtype, **kwargs)
10876

10977
def full_like(
11078
x: ndarray,
@@ -114,9 +82,10 @@ def full_like(
11482
xp,
11583
dtype: Optional[Dtype] = None,
11684
device: Optional[Device] = None,
85+
**kwargs,
11786
) -> ndarray:
11887
_check_device(xp, device)
119-
return xp.full_like(x, fill_value, dtype=dtype)
88+
return xp.full_like(x, fill_value, dtype=dtype, **kwargs)
12089

12190
def linspace(
12291
start: Union[int, float],
@@ -128,41 +97,46 @@ def linspace(
12897
dtype: Optional[Dtype] = None,
12998
device: Optional[Device] = None,
13099
endpoint: bool = True,
100+
**kwargs,
131101
) -> ndarray:
132102
_check_device(xp, device)
133-
return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint)
103+
return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs)
134104

135105
def ones(
136106
shape: Union[int, Tuple[int, ...]],
137107
xp,
138108
*,
139109
dtype: Optional[Dtype] = None,
140110
device: Optional[Device] = None,
111+
**kwargs,
141112
) -> ndarray:
142113
_check_device(xp, device)
143-
return xp.ones(shape, dtype=dtype)
114+
return xp.ones(shape, dtype=dtype, **kwargs)
144115

145116
def ones_like(
146-
x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
117+
x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None,
118+
**kwargs,
147119
) -> ndarray:
148120
_check_device(xp, device)
149-
return xp.ones_like(x, dtype=dtype)
121+
return xp.ones_like(x, dtype=dtype, **kwargs)
150122

151123
def zeros(
152124
shape: Union[int, Tuple[int, ...]],
153125
xp,
154126
*,
155127
dtype: Optional[Dtype] = None,
156128
device: Optional[Device] = None,
129+
**kwargs,
157130
) -> ndarray:
158131
_check_device(xp, device)
159-
return xp.zeros(shape, dtype=dtype)
132+
return xp.zeros(shape, dtype=dtype, **kwargs)
160133

161134
def zeros_like(
162-
x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
135+
x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None,
136+
**kwargs,
163137
) -> ndarray:
164138
_check_device(xp, device)
165-
return xp.zeros_like(x, dtype=dtype)
139+
return xp.zeros_like(x, dtype=dtype, **kwargs)
166140

167141
# np.unique() is split into four functions in the array API:
168142
# unique_all, unique_counts, unique_inverse, and unique_values (this is done
@@ -256,8 +230,9 @@ def std(
256230
axis: Optional[Union[int, Tuple[int, ...]]] = None,
257231
correction: Union[int, float] = 0.0, # correction instead of ddof
258232
keepdims: bool = False,
233+
**kwargs,
259234
) -> ndarray:
260-
return xp.std(x, axis=axis, ddof=correction, keepdims=keepdims)
235+
return xp.std(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)
261236

262237
def var(
263238
x: ndarray,
@@ -267,8 +242,9 @@ def var(
267242
axis: Optional[Union[int, Tuple[int, ...]]] = None,
268243
correction: Union[int, float] = 0.0, # correction instead of ddof
269244
keepdims: bool = False,
245+
**kwargs,
270246
) -> ndarray:
271-
return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims)
247+
return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)
272248

273249
# Unlike transpose(), the axes argument to permute_dims() is required.
274250
def permute_dims(x: ndarray, /, axes: Tuple[int, ...], xp) -> ndarray:
@@ -292,6 +268,7 @@ def _asarray(
292268
device: Optional[Device] = None,
293269
copy: "Optional[Union[bool, np._CopyMode]]" = None,
294270
namespace = None,
271+
**kwargs,
295272
) -> ndarray:
296273
"""
297274
Array API compatibility wrapper for asarray().
@@ -333,33 +310,39 @@ def _asarray(
333310
return xp.array(obj, copy=True, dtype=dtype)
334311
return obj
335312

336-
return xp.asarray(obj, dtype=dtype)
313+
return xp.asarray(obj, dtype=dtype, **kwargs)
337314

338315
# xp.reshape calls the keyword argument 'newshape' instead of 'shape'
339-
def reshape(x: ndarray, /, shape: Tuple[int, ...], xp, copy: Optional[bool] = None) -> ndarray:
316+
def reshape(x: ndarray,
317+
/,
318+
shape: Tuple[int, ...],
319+
xp, copy: Optional[bool] = None,
320+
**kwargs) -> ndarray:
340321
if copy is True:
341322
x = x.copy()
342323
elif copy is False:
343324
x.shape = shape
344325
return x
345-
return xp.reshape(x, shape)
326+
return xp.reshape(x, shape, **kwargs)
346327

347328
# The descending keyword is new in sort and argsort, and 'kind' replaced with
348329
# 'stable'
349330
def argsort(
350-
x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True
331+
x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True,
332+
**kwargs,
351333
) -> ndarray:
352334
# Note: this keyword argument is different, and the default is different.
353335
kind = "stable" if stable else "quicksort"
354336
if not descending:
355-
res = xp.argsort(x, axis=axis, kind=kind)
337+
res = xp.argsort(x, axis=axis, kind=kind, **kwargs)
356338
else:
357339
# As NumPy has no native descending sort, we imitate it here. Note that
358340
# simply flipping the results of xp.argsort(x, ...) would not
359341
# respect the relative order like it would in native descending sorts.
360342
res = xp.flip(
361343
xp.argsort(xp.flip(x, axis=axis), axis=axis, kind=kind),
362344
axis=axis,
345+
**kwargs,
363346
)
364347
# Rely on flip()/argsort() to validate axis
365348
normalised_axis = axis if axis >= 0 else x.ndim + axis
@@ -368,11 +351,12 @@ def argsort(
368351
return res
369352

370353
def sort(
371-
x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True
354+
x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True,
355+
**kwargs,
372356
) -> ndarray:
373357
# Note: this keyword argument is different, and the default is different.
374358
kind = "stable" if stable else "quicksort"
375-
res = xp.sort(x, axis=axis, kind=kind)
359+
res = xp.sort(x, axis=axis, kind=kind, **kwargs)
376360
if descending:
377361
res = xp.flip(res, axis=axis)
378362
return res
@@ -386,11 +370,12 @@ def sum(
386370
axis: Optional[Union[int, Tuple[int, ...]]] = None,
387371
dtype: Optional[Dtype] = None,
388372
keepdims: bool = False,
373+
**kwargs,
389374
) -> ndarray:
390375
# `xp.sum` already upcasts integers, but not floats
391376
if dtype is None and x.dtype == xp.float32:
392377
dtype = xp.float64
393-
return xp.sum(x, axis=axis, dtype=dtype, keepdims=keepdims)
378+
return xp.sum(x, axis=axis, dtype=dtype, keepdims=keepdims, **kwargs)
394379

395380
def prod(
396381
x: ndarray,
@@ -400,32 +385,30 @@ def prod(
400385
axis: Optional[Union[int, Tuple[int, ...]]] = None,
401386
dtype: Optional[Dtype] = None,
402387
keepdims: bool = False,
388+
**kwargs,
403389
) -> ndarray:
404390
if dtype is None and x.dtype == xp.float32:
405391
dtype = xp.float64
406-
return xp.prod(x, dtype=dtype, axis=axis, keepdims=keepdims)
392+
return xp.prod(x, dtype=dtype, axis=axis, keepdims=keepdims, **kwargs)
407393

408394
# ceil, floor, and trunc return integers for integer inputs
409395

410-
def ceil(x: ndarray, /, xp) -> ndarray:
396+
def ceil(x: ndarray, /, xp, **kwargs) -> ndarray:
411397
if xp.issubdtype(x.dtype, xp.integer):
412398
return x
413-
return xp.ceil(x)
399+
return xp.ceil(x, **kwargs)
414400

415-
def floor(x: ndarray, /, xp) -> ndarray:
401+
def floor(x: ndarray, /, xp, **kwargs) -> ndarray:
416402
if xp.issubdtype(x.dtype, xp.integer):
417403
return x
418-
return xp.floor(x)
404+
return xp.floor(x, **kwargs)
419405

420-
def trunc(x: ndarray, /, xp) -> ndarray:
406+
def trunc(x: ndarray, /, xp, **kwargs) -> ndarray:
421407
if xp.issubdtype(x.dtype, xp.integer):
422408
return x
423-
return xp.trunc(x)
424-
425-
__all__ = ['acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh',
426-
'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift',
427-
'concat', 'pow', 'UniqueAllResult', 'UniqueCountsResult',
428-
'UniqueInverseResult', 'unique_all', 'unique_counts',
429-
'unique_inverse', 'unique_values', 'astype', 'std', 'var',
430-
'permute_dims', 'reshape', 'argsort', 'sort', 'sum', 'prod',
431-
'ceil', 'floor', 'trunc']
409+
return xp.trunc(x, **kwargs)
410+
411+
__all__ = ['UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
412+
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
413+
'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort',
414+
'sort', 'sum', 'prod', 'ceil', 'floor', 'trunc']

0 commit comments

Comments
 (0)