@@ -25,25 +25,28 @@ def arange(
25
25
xp ,
26
26
dtype : Optional [Dtype ] = None ,
27
27
device : Optional [Device ] = None ,
28
+ ** kwargs
28
29
) -> ndarray :
29
30
_check_device (xp , device )
30
- return xp .arange (start , stop = stop , step = step , dtype = dtype )
31
+ return xp .arange (start , stop = stop , step = step , dtype = dtype , ** kwargs )
31
32
32
33
def empty (
33
34
shape : Union [int , Tuple [int , ...]],
34
35
xp ,
35
36
* ,
36
37
dtype : Optional [Dtype ] = None ,
37
38
device : Optional [Device ] = None ,
39
+ ** kwargs
38
40
) -> ndarray :
39
41
_check_device (xp , device )
40
- return xp .empty (shape , dtype = dtype )
42
+ return xp .empty (shape , dtype = dtype , ** kwargs )
41
43
42
44
def empty_like (
43
- x : ndarray , / , xp , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None
45
+ x : ndarray , / , xp , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None ,
46
+ ** kwargs
44
47
) -> ndarray :
45
48
_check_device (xp , device )
46
- return xp .empty_like (x , dtype = dtype )
49
+ return xp .empty_like (x , dtype = dtype , ** kwargs )
47
50
48
51
def eye (
49
52
n_rows : int ,
@@ -54,9 +57,10 @@ def eye(
54
57
k : int = 0 ,
55
58
dtype : Optional [Dtype ] = None ,
56
59
device : Optional [Device ] = None ,
60
+ ** kwargs ,
57
61
) -> ndarray :
58
62
_check_device (xp , device )
59
- return xp .eye (n_rows , M = n_cols , k = k , dtype = dtype )
63
+ return xp .eye (n_rows , M = n_cols , k = k , dtype = dtype , ** kwargs )
60
64
61
65
def full (
62
66
shape : Union [int , Tuple [int , ...]],
@@ -65,9 +69,10 @@ def full(
65
69
* ,
66
70
dtype : Optional [Dtype ] = None ,
67
71
device : Optional [Device ] = None ,
72
+ ** kwargs ,
68
73
) -> ndarray :
69
74
_check_device (xp , device )
70
- return xp .full (shape , fill_value , dtype = dtype )
75
+ return xp .full (shape , fill_value , dtype = dtype , ** kwargs )
71
76
72
77
def full_like (
73
78
x : ndarray ,
@@ -77,9 +82,10 @@ def full_like(
77
82
xp ,
78
83
dtype : Optional [Dtype ] = None ,
79
84
device : Optional [Device ] = None ,
85
+ ** kwargs ,
80
86
) -> ndarray :
81
87
_check_device (xp , device )
82
- return xp .full_like (x , fill_value , dtype = dtype )
88
+ return xp .full_like (x , fill_value , dtype = dtype , ** kwargs )
83
89
84
90
def linspace (
85
91
start : Union [int , float ],
@@ -91,41 +97,46 @@ def linspace(
91
97
dtype : Optional [Dtype ] = None ,
92
98
device : Optional [Device ] = None ,
93
99
endpoint : bool = True ,
100
+ ** kwargs ,
94
101
) -> ndarray :
95
102
_check_device (xp , device )
96
- return xp .linspace (start , stop , num , dtype = dtype , endpoint = endpoint )
103
+ return xp .linspace (start , stop , num , dtype = dtype , endpoint = endpoint , ** kwargs )
97
104
98
105
def ones (
99
106
shape : Union [int , Tuple [int , ...]],
100
107
xp ,
101
108
* ,
102
109
dtype : Optional [Dtype ] = None ,
103
110
device : Optional [Device ] = None ,
111
+ ** kwargs ,
104
112
) -> ndarray :
105
113
_check_device (xp , device )
106
- return xp .ones (shape , dtype = dtype )
114
+ return xp .ones (shape , dtype = dtype , ** kwargs )
107
115
108
116
def ones_like (
109
- x : ndarray , / , xp , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None
117
+ x : ndarray , / , xp , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None ,
118
+ ** kwargs ,
110
119
) -> ndarray :
111
120
_check_device (xp , device )
112
- return xp .ones_like (x , dtype = dtype )
121
+ return xp .ones_like (x , dtype = dtype , ** kwargs )
113
122
114
123
def zeros (
115
124
shape : Union [int , Tuple [int , ...]],
116
125
xp ,
117
126
* ,
118
127
dtype : Optional [Dtype ] = None ,
119
128
device : Optional [Device ] = None ,
129
+ ** kwargs ,
120
130
) -> ndarray :
121
131
_check_device (xp , device )
122
- return xp .zeros (shape , dtype = dtype )
132
+ return xp .zeros (shape , dtype = dtype , ** kwargs )
123
133
124
134
def zeros_like (
125
- x : ndarray , / , xp , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None
135
+ x : ndarray , / , xp , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None ,
136
+ ** kwargs ,
126
137
) -> ndarray :
127
138
_check_device (xp , device )
128
- return xp .zeros_like (x , dtype = dtype )
139
+ return xp .zeros_like (x , dtype = dtype , ** kwargs )
129
140
130
141
# np.unique() is split into four functions in the array API:
131
142
# unique_all, unique_counts, unique_inverse, and unique_values (this is done
@@ -219,8 +230,9 @@ def std(
219
230
axis : Optional [Union [int , Tuple [int , ...]]] = None ,
220
231
correction : Union [int , float ] = 0.0 , # correction instead of ddof
221
232
keepdims : bool = False ,
233
+ ** kwargs ,
222
234
) -> ndarray :
223
- return xp .std (x , axis = axis , ddof = correction , keepdims = keepdims )
235
+ return xp .std (x , axis = axis , ddof = correction , keepdims = keepdims , ** kwargs )
224
236
225
237
def var (
226
238
x : ndarray ,
@@ -230,8 +242,9 @@ def var(
230
242
axis : Optional [Union [int , Tuple [int , ...]]] = None ,
231
243
correction : Union [int , float ] = 0.0 , # correction instead of ddof
232
244
keepdims : bool = False ,
245
+ ** kwargs ,
233
246
) -> ndarray :
234
- return xp .var (x , axis = axis , ddof = correction , keepdims = keepdims )
247
+ return xp .var (x , axis = axis , ddof = correction , keepdims = keepdims , ** kwargs )
235
248
236
249
# Unlike transpose(), the axes argument to permute_dims() is required.
237
250
def permute_dims (x : ndarray , / , axes : Tuple [int , ...], xp ) -> ndarray :
@@ -255,6 +268,7 @@ def _asarray(
255
268
device : Optional [Device ] = None ,
256
269
copy : "Optional[Union[bool, np._CopyMode]]" = None ,
257
270
namespace = None ,
271
+ ** kwargs ,
258
272
) -> ndarray :
259
273
"""
260
274
Array API compatibility wrapper for asarray().
@@ -296,33 +310,39 @@ def _asarray(
296
310
return xp .array (obj , copy = True , dtype = dtype )
297
311
return obj
298
312
299
- return xp .asarray (obj , dtype = dtype )
313
+ return xp .asarray (obj , dtype = dtype , ** kwargs )
300
314
301
315
# xp.reshape calls the keyword argument 'newshape' instead of 'shape'
302
- def reshape (x : ndarray , / , shape : Tuple [int , ...], xp , copy : Optional [bool ] = None ) -> ndarray :
316
+ def reshape (x : ndarray ,
317
+ / ,
318
+ shape : Tuple [int , ...],
319
+ xp , copy : Optional [bool ] = None ,
320
+ ** kwargs ) -> ndarray :
303
321
if copy is True :
304
322
x = x .copy ()
305
323
elif copy is False :
306
324
x .shape = shape
307
325
return x
308
- return xp .reshape (x , shape )
326
+ return xp .reshape (x , shape , ** kwargs )
309
327
310
328
# The descending keyword is new in sort and argsort, and 'kind' replaced with
311
329
# 'stable'
312
330
def argsort (
313
- x : ndarray , / , xp , * , axis : int = - 1 , descending : bool = False , stable : bool = True
331
+ x : ndarray , / , xp , * , axis : int = - 1 , descending : bool = False , stable : bool = True ,
332
+ ** kwargs ,
314
333
) -> ndarray :
315
334
# Note: this keyword argument is different, and the default is different.
316
335
kind = "stable" if stable else "quicksort"
317
336
if not descending :
318
- res = xp .argsort (x , axis = axis , kind = kind )
337
+ res = xp .argsort (x , axis = axis , kind = kind , ** kwargs )
319
338
else :
320
339
# As NumPy has no native descending sort, we imitate it here. Note that
321
340
# simply flipping the results of xp.argsort(x, ...) would not
322
341
# respect the relative order like it would in native descending sorts.
323
342
res = xp .flip (
324
343
xp .argsort (xp .flip (x , axis = axis ), axis = axis , kind = kind ),
325
344
axis = axis ,
345
+ ** kwargs ,
326
346
)
327
347
# Rely on flip()/argsort() to validate axis
328
348
normalised_axis = axis if axis >= 0 else x .ndim + axis
@@ -331,11 +351,12 @@ def argsort(
331
351
return res
332
352
333
353
def sort (
334
- x : ndarray , / , xp , * , axis : int = - 1 , descending : bool = False , stable : bool = True
354
+ x : ndarray , / , xp , * , axis : int = - 1 , descending : bool = False , stable : bool = True ,
355
+ ** kwargs ,
335
356
) -> ndarray :
336
357
# Note: this keyword argument is different, and the default is different.
337
358
kind = "stable" if stable else "quicksort"
338
- res = xp .sort (x , axis = axis , kind = kind )
359
+ res = xp .sort (x , axis = axis , kind = kind , ** kwargs )
339
360
if descending :
340
361
res = xp .flip (res , axis = axis )
341
362
return res
@@ -349,11 +370,12 @@ def sum(
349
370
axis : Optional [Union [int , Tuple [int , ...]]] = None ,
350
371
dtype : Optional [Dtype ] = None ,
351
372
keepdims : bool = False ,
373
+ ** kwargs ,
352
374
) -> ndarray :
353
375
# `xp.sum` already upcasts integers, but not floats
354
376
if dtype is None and x .dtype == xp .float32 :
355
377
dtype = xp .float64
356
- return xp .sum (x , axis = axis , dtype = dtype , keepdims = keepdims )
378
+ return xp .sum (x , axis = axis , dtype = dtype , keepdims = keepdims , ** kwargs )
357
379
358
380
def prod (
359
381
x : ndarray ,
@@ -363,27 +385,28 @@ def prod(
363
385
axis : Optional [Union [int , Tuple [int , ...]]] = None ,
364
386
dtype : Optional [Dtype ] = None ,
365
387
keepdims : bool = False ,
388
+ ** kwargs ,
366
389
) -> ndarray :
367
390
if dtype is None and x .dtype == xp .float32 :
368
391
dtype = xp .float64
369
- return xp .prod (x , dtype = dtype , axis = axis , keepdims = keepdims )
392
+ return xp .prod (x , dtype = dtype , axis = axis , keepdims = keepdims , ** kwargs )
370
393
371
394
# ceil, floor, and trunc return integers for integer inputs
372
395
373
- def ceil (x : ndarray , / , xp ) -> ndarray :
396
+ def ceil (x : ndarray , / , xp , ** kwargs ) -> ndarray :
374
397
if xp .issubdtype (x .dtype , xp .integer ):
375
398
return x
376
- return xp .ceil (x )
399
+ return xp .ceil (x , ** kwargs )
377
400
378
- def floor (x : ndarray , / , xp ) -> ndarray :
401
+ def floor (x : ndarray , / , xp , ** kwargs ) -> ndarray :
379
402
if xp .issubdtype (x .dtype , xp .integer ):
380
403
return x
381
- return xp .floor (x )
404
+ return xp .floor (x , ** kwargs )
382
405
383
- def trunc (x : ndarray , / , xp ) -> ndarray :
406
+ def trunc (x : ndarray , / , xp , ** kwargs ) -> ndarray :
384
407
if xp .issubdtype (x .dtype , xp .integer ):
385
408
return x
386
- return xp .trunc (x )
409
+ return xp .trunc (x , ** kwargs )
387
410
388
411
__all__ = ['UniqueAllResult' , 'UniqueCountsResult' , 'UniqueInverseResult' ,
389
412
'unique_all' , 'unique_counts' , 'unique_inverse' , 'unique_values' ,
0 commit comments