Skip to content

Commit 055f793

Browse files
committed
Pass **kwargs through to wrapped functions
This allows passing NumPy-specific keyword arguments through, like asarray(order='F'). This won't work with keyword arguments that are renamed. This is not done for new functions that aren't in NumPy, like unique_all, permute_dims, matrix_norm, etc. Fixes ##10.
1 parent f4748e5 commit 055f793

File tree

2 files changed

+93
-57
lines changed

2 files changed

+93
-57
lines changed

array_api_compat/common/_aliases.py

Lines changed: 54 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -25,25 +25,28 @@ def arange(
2525
xp,
2626
dtype: Optional[Dtype] = None,
2727
device: Optional[Device] = None,
28+
**kwargs
2829
) -> ndarray:
2930
_check_device(xp, device)
30-
return xp.arange(start, stop=stop, step=step, dtype=dtype)
31+
return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs)
3132

3233
def empty(
3334
shape: Union[int, Tuple[int, ...]],
3435
xp,
3536
*,
3637
dtype: Optional[Dtype] = None,
3738
device: Optional[Device] = None,
39+
**kwargs
3840
) -> ndarray:
3941
_check_device(xp, device)
40-
return xp.empty(shape, dtype=dtype)
42+
return xp.empty(shape, dtype=dtype, **kwargs)
4143

4244
def empty_like(
43-
x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
45+
x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None,
46+
**kwargs
4447
) -> ndarray:
4548
_check_device(xp, device)
46-
return xp.empty_like(x, dtype=dtype)
49+
return xp.empty_like(x, dtype=dtype, **kwargs)
4750

4851
def eye(
4952
n_rows: int,
@@ -54,9 +57,10 @@ def eye(
5457
k: int = 0,
5558
dtype: Optional[Dtype] = None,
5659
device: Optional[Device] = None,
60+
**kwargs,
5761
) -> ndarray:
5862
_check_device(xp, device)
59-
return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype)
63+
return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs)
6064

6165
def full(
6266
shape: Union[int, Tuple[int, ...]],
@@ -65,9 +69,10 @@ def full(
6569
*,
6670
dtype: Optional[Dtype] = None,
6771
device: Optional[Device] = None,
72+
**kwargs,
6873
) -> ndarray:
6974
_check_device(xp, device)
70-
return xp.full(shape, fill_value, dtype=dtype)
75+
return xp.full(shape, fill_value, dtype=dtype, **kwargs)
7176

7277
def full_like(
7378
x: ndarray,
@@ -77,9 +82,10 @@ def full_like(
7782
xp,
7883
dtype: Optional[Dtype] = None,
7984
device: Optional[Device] = None,
85+
**kwargs,
8086
) -> ndarray:
8187
_check_device(xp, device)
82-
return xp.full_like(x, fill_value, dtype=dtype)
88+
return xp.full_like(x, fill_value, dtype=dtype, **kwargs)
8389

8490
def linspace(
8591
start: Union[int, float],
@@ -91,41 +97,46 @@ def linspace(
9197
dtype: Optional[Dtype] = None,
9298
device: Optional[Device] = None,
9399
endpoint: bool = True,
100+
**kwargs,
94101
) -> ndarray:
95102
_check_device(xp, device)
96-
return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint)
103+
return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs)
97104

98105
def ones(
99106
shape: Union[int, Tuple[int, ...]],
100107
xp,
101108
*,
102109
dtype: Optional[Dtype] = None,
103110
device: Optional[Device] = None,
111+
**kwargs,
104112
) -> ndarray:
105113
_check_device(xp, device)
106-
return xp.ones(shape, dtype=dtype)
114+
return xp.ones(shape, dtype=dtype, **kwargs)
107115

108116
def ones_like(
109-
x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
117+
x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None,
118+
**kwargs,
110119
) -> ndarray:
111120
_check_device(xp, device)
112-
return xp.ones_like(x, dtype=dtype)
121+
return xp.ones_like(x, dtype=dtype, **kwargs)
113122

114123
def zeros(
115124
shape: Union[int, Tuple[int, ...]],
116125
xp,
117126
*,
118127
dtype: Optional[Dtype] = None,
119128
device: Optional[Device] = None,
129+
**kwargs,
120130
) -> ndarray:
121131
_check_device(xp, device)
122-
return xp.zeros(shape, dtype=dtype)
132+
return xp.zeros(shape, dtype=dtype, **kwargs)
123133

124134
def zeros_like(
125-
x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
135+
x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None,
136+
**kwargs,
126137
) -> ndarray:
127138
_check_device(xp, device)
128-
return xp.zeros_like(x, dtype=dtype)
139+
return xp.zeros_like(x, dtype=dtype, **kwargs)
129140

130141
# np.unique() is split into four functions in the array API:
131142
# unique_all, unique_counts, unique_inverse, and unique_values (this is done
@@ -219,8 +230,9 @@ def std(
219230
axis: Optional[Union[int, Tuple[int, ...]]] = None,
220231
correction: Union[int, float] = 0.0, # correction instead of ddof
221232
keepdims: bool = False,
233+
**kwargs,
222234
) -> ndarray:
223-
return xp.std(x, axis=axis, ddof=correction, keepdims=keepdims)
235+
return xp.std(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)
224236

225237
def var(
226238
x: ndarray,
@@ -230,8 +242,9 @@ def var(
230242
axis: Optional[Union[int, Tuple[int, ...]]] = None,
231243
correction: Union[int, float] = 0.0, # correction instead of ddof
232244
keepdims: bool = False,
245+
**kwargs,
233246
) -> ndarray:
234-
return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims)
247+
return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)
235248

236249
# Unlike transpose(), the axes argument to permute_dims() is required.
237250
def permute_dims(x: ndarray, /, axes: Tuple[int, ...], xp) -> ndarray:
@@ -255,6 +268,7 @@ def _asarray(
255268
device: Optional[Device] = None,
256269
copy: "Optional[Union[bool, np._CopyMode]]" = None,
257270
namespace = None,
271+
**kwargs,
258272
) -> ndarray:
259273
"""
260274
Array API compatibility wrapper for asarray().
@@ -296,33 +310,39 @@ def _asarray(
296310
return xp.array(obj, copy=True, dtype=dtype)
297311
return obj
298312

299-
return xp.asarray(obj, dtype=dtype)
313+
return xp.asarray(obj, dtype=dtype, **kwargs)
300314

301315
# xp.reshape calls the keyword argument 'newshape' instead of 'shape'
302-
def reshape(x: ndarray, /, shape: Tuple[int, ...], xp, copy: Optional[bool] = None) -> ndarray:
316+
def reshape(x: ndarray,
317+
/,
318+
shape: Tuple[int, ...],
319+
xp, copy: Optional[bool] = None,
320+
**kwargs) -> ndarray:
303321
if copy is True:
304322
x = x.copy()
305323
elif copy is False:
306324
x.shape = shape
307325
return x
308-
return xp.reshape(x, shape)
326+
return xp.reshape(x, shape, **kwargs)
309327

310328
# The descending keyword is new in sort and argsort, and 'kind' replaced with
311329
# 'stable'
312330
def argsort(
313-
x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True
331+
x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True,
332+
**kwargs,
314333
) -> ndarray:
315334
# Note: this keyword argument is different, and the default is different.
316335
kind = "stable" if stable else "quicksort"
317336
if not descending:
318-
res = xp.argsort(x, axis=axis, kind=kind)
337+
res = xp.argsort(x, axis=axis, kind=kind, **kwargs)
319338
else:
320339
# As NumPy has no native descending sort, we imitate it here. Note that
321340
# simply flipping the results of xp.argsort(x, ...) would not
322341
# respect the relative order like it would in native descending sorts.
323342
res = xp.flip(
324343
xp.argsort(xp.flip(x, axis=axis), axis=axis, kind=kind),
325344
axis=axis,
345+
**kwargs,
326346
)
327347
# Rely on flip()/argsort() to validate axis
328348
normalised_axis = axis if axis >= 0 else x.ndim + axis
@@ -331,11 +351,12 @@ def argsort(
331351
return res
332352

333353
def sort(
334-
x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True
354+
x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True,
355+
**kwargs,
335356
) -> ndarray:
336357
# Note: this keyword argument is different, and the default is different.
337358
kind = "stable" if stable else "quicksort"
338-
res = xp.sort(x, axis=axis, kind=kind)
359+
res = xp.sort(x, axis=axis, kind=kind, **kwargs)
339360
if descending:
340361
res = xp.flip(res, axis=axis)
341362
return res
@@ -349,11 +370,12 @@ def sum(
349370
axis: Optional[Union[int, Tuple[int, ...]]] = None,
350371
dtype: Optional[Dtype] = None,
351372
keepdims: bool = False,
373+
**kwargs,
352374
) -> ndarray:
353375
# `xp.sum` already upcasts integers, but not floats
354376
if dtype is None and x.dtype == xp.float32:
355377
dtype = xp.float64
356-
return xp.sum(x, axis=axis, dtype=dtype, keepdims=keepdims)
378+
return xp.sum(x, axis=axis, dtype=dtype, keepdims=keepdims, **kwargs)
357379

358380
def prod(
359381
x: ndarray,
@@ -363,27 +385,28 @@ def prod(
363385
axis: Optional[Union[int, Tuple[int, ...]]] = None,
364386
dtype: Optional[Dtype] = None,
365387
keepdims: bool = False,
388+
**kwargs,
366389
) -> ndarray:
367390
if dtype is None and x.dtype == xp.float32:
368391
dtype = xp.float64
369-
return xp.prod(x, dtype=dtype, axis=axis, keepdims=keepdims)
392+
return xp.prod(x, dtype=dtype, axis=axis, keepdims=keepdims, **kwargs)
370393

371394
# ceil, floor, and trunc return integers for integer inputs
372395

373-
def ceil(x: ndarray, /, xp) -> ndarray:
396+
def ceil(x: ndarray, /, xp, **kwargs) -> ndarray:
374397
if xp.issubdtype(x.dtype, xp.integer):
375398
return x
376-
return xp.ceil(x)
399+
return xp.ceil(x, **kwargs)
377400

378-
def floor(x: ndarray, /, xp) -> ndarray:
401+
def floor(x: ndarray, /, xp, **kwargs) -> ndarray:
379402
if xp.issubdtype(x.dtype, xp.integer):
380403
return x
381-
return xp.floor(x)
404+
return xp.floor(x, **kwargs)
382405

383-
def trunc(x: ndarray, /, xp) -> ndarray:
406+
def trunc(x: ndarray, /, xp, **kwargs) -> ndarray:
384407
if xp.issubdtype(x.dtype, xp.integer):
385408
return x
386-
return xp.trunc(x)
409+
return xp.trunc(x, **kwargs)
387410

388411
__all__ = ['UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
389412
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',

array_api_compat/common/_linalg.py

Lines changed: 39 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,24 @@
1010
from .._internal import get_xp
1111

1212
# These are in the main NumPy namespace but not in numpy.linalg
13-
def cross(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray:
14-
return xp.cross(x1, x2, axis=axis)
13+
def cross(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1, **kwargs) -> ndarray:
14+
return xp.cross(x1, x2, axis=axis, **kwargs)
1515

16-
def matmul(x1: ndarray, x2: ndarray, /, xp) -> ndarray:
17-
return xp.matmul(x1, x2)
16+
def matmul(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray:
17+
return xp.matmul(x1, x2, **kwargs)
1818

19-
def outer(x1: ndarray, x2: ndarray, /, xp) -> ndarray:
20-
return xp.outer(x1, x2)
19+
def outer(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray:
20+
return xp.outer(x1, x2, **kwargs)
2121

22-
def tensordot(x1: ndarray, x2: ndarray, /, xp, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> ndarray:
23-
return xp.tensordot(x1, x2, axes=axes)
22+
def tensordot(x1: ndarray,
23+
x2: ndarray,
24+
/,
25+
xp,
26+
*,
27+
axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2,
28+
**kwargs,
29+
) -> ndarray:
30+
return xp.tensordot(x1, x2, axes=axes, **kwargs)
2431

2532
class EighResult(NamedTuple):
2633
eigenvalues: ndarray
@@ -41,35 +48,41 @@ class SVDResult(NamedTuple):
4148

4249
# These functions are the same as their NumPy counterparts except they return
4350
# a namedtuple.
44-
def eigh(x: ndarray, /, xp) -> EighResult:
45-
return EighResult(*xp.linalg.eigh(x))
51+
def eigh(x: ndarray, /, xp, **kwargs) -> EighResult:
52+
return EighResult(*xp.linalg.eigh(x, **kwargs))
4653

47-
def qr(x: ndarray, /, xp, *, mode: Literal['reduced', 'complete'] = 'reduced') -> QRResult:
48-
return QRResult(*xp.linalg.qr(x, mode=mode))
54+
def qr(x: ndarray, /, xp, *, mode: Literal['reduced', 'complete'] = 'reduced',
55+
**kwargs) -> QRResult:
56+
return QRResult(*xp.linalg.qr(x, mode=mode, **kwargs))
4957

50-
def slogdet(x: ndarray, /, xp) -> SlogdetResult:
51-
return SlogdetResult(*xp.linalg.slogdet(x))
58+
def slogdet(x: ndarray, /, xp, **kwargs) -> SlogdetResult:
59+
return SlogdetResult(*xp.linalg.slogdet(x, **kwargs))
5260

53-
def svd(x: ndarray, /, xp, *, full_matrices: bool = True) -> SVDResult:
54-
return SVDResult(*xp.linalg.svd(x, full_matrices=full_matrices))
61+
def svd(x: ndarray, /, xp, *, full_matrices: bool = True, **kwargs) -> SVDResult:
62+
return SVDResult(*xp.linalg.svd(x, full_matrices=full_matrices, **kwargs))
5563

5664
# These functions have additional keyword arguments
5765

5866
# The upper keyword argument is new from NumPy
59-
def cholesky(x: ndarray, /, xp, *, upper: bool = False) -> ndarray:
60-
L = xp.linalg.cholesky(x)
67+
def cholesky(x: ndarray, /, xp, *, upper: bool = False, **kwargs) -> ndarray:
68+
L = xp.linalg.cholesky(x, **kwargs)
6169
if upper:
6270
return get_xp(xp)(matrix_transpose)(L)
6371
return L
6472

6573
# The rtol keyword argument of matrix_rank() and pinv() is new from NumPy.
6674
# Note that it has a different semantic meaning from tol and rcond.
67-
def matrix_rank(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None) -> ndarray:
75+
def matrix_rank(x: ndarray,
76+
/,
77+
xp,
78+
*,
79+
rtol: Optional[Union[float, ndarray]] = None,
80+
**kwargs) -> ndarray:
6881
# this is different from xp.linalg.matrix_rank, which supports 1
6982
# dimensional arrays.
7083
if x.ndim < 2:
7184
raise xp.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional")
72-
S = xp.linalg.svd(x, compute_uv=False)
85+
S = xp.linalg.svd(x, compute_uv=False, **kwargs)
7386
if rtol is None:
7487
tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * xp.finfo(S.dtype).eps
7588
else:
@@ -78,12 +91,12 @@ def matrix_rank(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = No
7891
tol = S.max(axis=-1, keepdims=True)*xp.asarray(rtol)[..., xp.newaxis]
7992
return xp.count_nonzero(S > tol, axis=-1)
8093

81-
def pinv(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None) -> ndarray:
94+
def pinv(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None, **kwargs) -> ndarray:
8295
# this is different from xp.linalg.pinv, which does not multiply the
8396
# default tolerance by max(M, N).
8497
if rtol is None:
8598
rtol = max(x.shape[-2:]) * xp.finfo(x.dtype).eps
86-
return xp.linalg.pinv(x, rcond=rtol)
99+
return xp.linalg.pinv(x, rcond=rtol, **kwargs)
87100

88101
# These functions are new in the array API spec
89102

@@ -152,11 +165,11 @@ def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]]
152165
# xp.diagonal and xp.trace operate on the first two axes whereas these
153166
# operates on the last two
154167

155-
def diagonal(x: ndarray, /, xp, *, offset: int = 0) -> ndarray:
156-
return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1)
168+
def diagonal(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray:
169+
return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs)
157170

158-
def trace(x: ndarray, /, xp, *, offset: int = 0) -> ndarray:
159-
return xp.asarray(xp.trace(x, offset=offset, axis1=-2, axis2=-1))
171+
def trace(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray:
172+
return xp.asarray(xp.trace(x, offset=offset, axis1=-2, axis2=-1, **kwargs))
160173

161174
__all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult',
162175
'QRResult', 'SlogdetResult', 'SVDResult', 'eigh', 'qr', 'slogdet',

0 commit comments

Comments
 (0)