14
14
15
15
from ._helpers import _check_device , _is_numpy_array , get_namespace
16
16
17
- # Basic renames
18
- def acos (x , / , xp ):
19
- return xp .arccos (x )
20
-
21
- def acosh (x , / , xp ):
22
- return xp .arccosh (x )
23
-
24
- def asin (x , / , xp ):
25
- return xp .arcsin (x )
26
-
27
- def asinh (x , / , xp ):
28
- return xp .arcsinh (x )
29
-
30
- def atan (x , / , xp ):
31
- return xp .arctan (x )
32
-
33
- def atan2 (x1 , x2 , / , xp ):
34
- return xp .arctan2 (x1 , x2 )
35
-
36
- def atanh (x , / , xp ):
37
- return xp .arctanh (x )
38
-
39
- def bitwise_left_shift (x1 , x2 , / , xp ):
40
- return xp .left_shift (x1 , x2 )
41
-
42
- def bitwise_invert (x , / , xp ):
43
- return xp .invert (x )
44
-
45
- def bitwise_right_shift (x1 , x2 , / , xp ):
46
- return xp .right_shift (x1 , x2 )
47
-
48
- def concat (arrays : Union [Tuple [ndarray , ...], List [ndarray ]], / , xp , * , axis : Optional [int ] = 0 ) -> ndarray :
49
- return xp .concatenate (arrays , axis = axis )
50
-
51
- def pow (x1 , x2 , / , xp ):
52
- return xp .power (x1 , x2 )
53
-
54
17
# These functions are modified from the NumPy versions.
55
18
56
19
def arange (
@@ -62,25 +25,28 @@ def arange(
62
25
xp ,
63
26
dtype : Optional [Dtype ] = None ,
64
27
device : Optional [Device ] = None ,
28
+ ** kwargs
65
29
) -> ndarray :
66
30
_check_device (xp , device )
67
- return xp .arange (start , stop = stop , step = step , dtype = dtype )
31
+ return xp .arange (start , stop = stop , step = step , dtype = dtype , ** kwargs )
68
32
69
33
def empty (
70
34
shape : Union [int , Tuple [int , ...]],
71
35
xp ,
72
36
* ,
73
37
dtype : Optional [Dtype ] = None ,
74
38
device : Optional [Device ] = None ,
39
+ ** kwargs
75
40
) -> ndarray :
76
41
_check_device (xp , device )
77
- return xp .empty (shape , dtype = dtype )
42
+ return xp .empty (shape , dtype = dtype , ** kwargs )
78
43
79
44
def empty_like (
80
- x : ndarray , / , xp , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None
45
+ x : ndarray , / , xp , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None ,
46
+ ** kwargs
81
47
) -> ndarray :
82
48
_check_device (xp , device )
83
- return xp .empty_like (x , dtype = dtype )
49
+ return xp .empty_like (x , dtype = dtype , ** kwargs )
84
50
85
51
def eye (
86
52
n_rows : int ,
@@ -91,9 +57,10 @@ def eye(
91
57
k : int = 0 ,
92
58
dtype : Optional [Dtype ] = None ,
93
59
device : Optional [Device ] = None ,
60
+ ** kwargs ,
94
61
) -> ndarray :
95
62
_check_device (xp , device )
96
- return xp .eye (n_rows , M = n_cols , k = k , dtype = dtype )
63
+ return xp .eye (n_rows , M = n_cols , k = k , dtype = dtype , ** kwargs )
97
64
98
65
def full (
99
66
shape : Union [int , Tuple [int , ...]],
@@ -102,9 +69,10 @@ def full(
102
69
* ,
103
70
dtype : Optional [Dtype ] = None ,
104
71
device : Optional [Device ] = None ,
72
+ ** kwargs ,
105
73
) -> ndarray :
106
74
_check_device (xp , device )
107
- return xp .full (shape , fill_value , dtype = dtype )
75
+ return xp .full (shape , fill_value , dtype = dtype , ** kwargs )
108
76
109
77
def full_like (
110
78
x : ndarray ,
@@ -114,9 +82,10 @@ def full_like(
114
82
xp ,
115
83
dtype : Optional [Dtype ] = None ,
116
84
device : Optional [Device ] = None ,
85
+ ** kwargs ,
117
86
) -> ndarray :
118
87
_check_device (xp , device )
119
- return xp .full_like (x , fill_value , dtype = dtype )
88
+ return xp .full_like (x , fill_value , dtype = dtype , ** kwargs )
120
89
121
90
def linspace (
122
91
start : Union [int , float ],
@@ -128,41 +97,46 @@ def linspace(
128
97
dtype : Optional [Dtype ] = None ,
129
98
device : Optional [Device ] = None ,
130
99
endpoint : bool = True ,
100
+ ** kwargs ,
131
101
) -> ndarray :
132
102
_check_device (xp , device )
133
- return xp .linspace (start , stop , num , dtype = dtype , endpoint = endpoint )
103
+ return xp .linspace (start , stop , num , dtype = dtype , endpoint = endpoint , ** kwargs )
134
104
135
105
def ones (
136
106
shape : Union [int , Tuple [int , ...]],
137
107
xp ,
138
108
* ,
139
109
dtype : Optional [Dtype ] = None ,
140
110
device : Optional [Device ] = None ,
111
+ ** kwargs ,
141
112
) -> ndarray :
142
113
_check_device (xp , device )
143
- return xp .ones (shape , dtype = dtype )
114
+ return xp .ones (shape , dtype = dtype , ** kwargs )
144
115
145
116
def ones_like (
146
- x : ndarray , / , xp , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None
117
+ x : ndarray , / , xp , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None ,
118
+ ** kwargs ,
147
119
) -> ndarray :
148
120
_check_device (xp , device )
149
- return xp .ones_like (x , dtype = dtype )
121
+ return xp .ones_like (x , dtype = dtype , ** kwargs )
150
122
151
123
def zeros (
152
124
shape : Union [int , Tuple [int , ...]],
153
125
xp ,
154
126
* ,
155
127
dtype : Optional [Dtype ] = None ,
156
128
device : Optional [Device ] = None ,
129
+ ** kwargs ,
157
130
) -> ndarray :
158
131
_check_device (xp , device )
159
- return xp .zeros (shape , dtype = dtype )
132
+ return xp .zeros (shape , dtype = dtype , ** kwargs )
160
133
161
134
def zeros_like (
162
- x : ndarray , / , xp , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None
135
+ x : ndarray , / , xp , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None ,
136
+ ** kwargs ,
163
137
) -> ndarray :
164
138
_check_device (xp , device )
165
- return xp .zeros_like (x , dtype = dtype )
139
+ return xp .zeros_like (x , dtype = dtype , ** kwargs )
166
140
167
141
# np.unique() is split into four functions in the array API:
168
142
# unique_all, unique_counts, unique_inverse, and unique_values (this is done
@@ -256,8 +230,9 @@ def std(
256
230
axis : Optional [Union [int , Tuple [int , ...]]] = None ,
257
231
correction : Union [int , float ] = 0.0 , # correction instead of ddof
258
232
keepdims : bool = False ,
233
+ ** kwargs ,
259
234
) -> ndarray :
260
- return xp .std (x , axis = axis , ddof = correction , keepdims = keepdims )
235
+ return xp .std (x , axis = axis , ddof = correction , keepdims = keepdims , ** kwargs )
261
236
262
237
def var (
263
238
x : ndarray ,
@@ -267,8 +242,9 @@ def var(
267
242
axis : Optional [Union [int , Tuple [int , ...]]] = None ,
268
243
correction : Union [int , float ] = 0.0 , # correction instead of ddof
269
244
keepdims : bool = False ,
245
+ ** kwargs ,
270
246
) -> ndarray :
271
- return xp .var (x , axis = axis , ddof = correction , keepdims = keepdims )
247
+ return xp .var (x , axis = axis , ddof = correction , keepdims = keepdims , ** kwargs )
272
248
273
249
# Unlike transpose(), the axes argument to permute_dims() is required.
274
250
def permute_dims (x : ndarray , / , axes : Tuple [int , ...], xp ) -> ndarray :
@@ -292,6 +268,7 @@ def _asarray(
292
268
device : Optional [Device ] = None ,
293
269
copy : "Optional[Union[bool, np._CopyMode]]" = None ,
294
270
namespace = None ,
271
+ ** kwargs ,
295
272
) -> ndarray :
296
273
"""
297
274
Array API compatibility wrapper for asarray().
@@ -333,33 +310,39 @@ def _asarray(
333
310
return xp .array (obj , copy = True , dtype = dtype )
334
311
return obj
335
312
336
- return xp .asarray (obj , dtype = dtype )
313
+ return xp .asarray (obj , dtype = dtype , ** kwargs )
337
314
338
315
# xp.reshape calls the keyword argument 'newshape' instead of 'shape'
339
- def reshape (x : ndarray , / , shape : Tuple [int , ...], xp , copy : Optional [bool ] = None ) -> ndarray :
316
+ def reshape (x : ndarray ,
317
+ / ,
318
+ shape : Tuple [int , ...],
319
+ xp , copy : Optional [bool ] = None ,
320
+ ** kwargs ) -> ndarray :
340
321
if copy is True :
341
322
x = x .copy ()
342
323
elif copy is False :
343
324
x .shape = shape
344
325
return x
345
- return xp .reshape (x , shape )
326
+ return xp .reshape (x , shape , ** kwargs )
346
327
347
328
# The descending keyword is new in sort and argsort, and 'kind' replaced with
348
329
# 'stable'
349
330
def argsort (
350
- x : ndarray , / , xp , * , axis : int = - 1 , descending : bool = False , stable : bool = True
331
+ x : ndarray , / , xp , * , axis : int = - 1 , descending : bool = False , stable : bool = True ,
332
+ ** kwargs ,
351
333
) -> ndarray :
352
334
# Note: this keyword argument is different, and the default is different.
353
335
kind = "stable" if stable else "quicksort"
354
336
if not descending :
355
- res = xp .argsort (x , axis = axis , kind = kind )
337
+ res = xp .argsort (x , axis = axis , kind = kind , ** kwargs )
356
338
else :
357
339
# As NumPy has no native descending sort, we imitate it here. Note that
358
340
# simply flipping the results of xp.argsort(x, ...) would not
359
341
# respect the relative order like it would in native descending sorts.
360
342
res = xp .flip (
361
343
xp .argsort (xp .flip (x , axis = axis ), axis = axis , kind = kind ),
362
344
axis = axis ,
345
+ ** kwargs ,
363
346
)
364
347
# Rely on flip()/argsort() to validate axis
365
348
normalised_axis = axis if axis >= 0 else x .ndim + axis
@@ -368,11 +351,12 @@ def argsort(
368
351
return res
369
352
370
353
def sort (
371
- x : ndarray , / , xp , * , axis : int = - 1 , descending : bool = False , stable : bool = True
354
+ x : ndarray , / , xp , * , axis : int = - 1 , descending : bool = False , stable : bool = True ,
355
+ ** kwargs ,
372
356
) -> ndarray :
373
357
# Note: this keyword argument is different, and the default is different.
374
358
kind = "stable" if stable else "quicksort"
375
- res = xp .sort (x , axis = axis , kind = kind )
359
+ res = xp .sort (x , axis = axis , kind = kind , ** kwargs )
376
360
if descending :
377
361
res = xp .flip (res , axis = axis )
378
362
return res
@@ -386,11 +370,12 @@ def sum(
386
370
axis : Optional [Union [int , Tuple [int , ...]]] = None ,
387
371
dtype : Optional [Dtype ] = None ,
388
372
keepdims : bool = False ,
373
+ ** kwargs ,
389
374
) -> ndarray :
390
375
# `xp.sum` already upcasts integers, but not floats
391
376
if dtype is None and x .dtype == xp .float32 :
392
377
dtype = xp .float64
393
- return xp .sum (x , axis = axis , dtype = dtype , keepdims = keepdims )
378
+ return xp .sum (x , axis = axis , dtype = dtype , keepdims = keepdims , ** kwargs )
394
379
395
380
def prod (
396
381
x : ndarray ,
@@ -400,32 +385,30 @@ def prod(
400
385
axis : Optional [Union [int , Tuple [int , ...]]] = None ,
401
386
dtype : Optional [Dtype ] = None ,
402
387
keepdims : bool = False ,
388
+ ** kwargs ,
403
389
) -> ndarray :
404
390
if dtype is None and x .dtype == xp .float32 :
405
391
dtype = xp .float64
406
- return xp .prod (x , dtype = dtype , axis = axis , keepdims = keepdims )
392
+ return xp .prod (x , dtype = dtype , axis = axis , keepdims = keepdims , ** kwargs )
407
393
408
394
# ceil, floor, and trunc return integers for integer inputs
409
395
410
- def ceil (x : ndarray , / , xp ) -> ndarray :
396
+ def ceil (x : ndarray , / , xp , ** kwargs ) -> ndarray :
411
397
if xp .issubdtype (x .dtype , xp .integer ):
412
398
return x
413
- return xp .ceil (x )
399
+ return xp .ceil (x , ** kwargs )
414
400
415
- def floor (x : ndarray , / , xp ) -> ndarray :
401
+ def floor (x : ndarray , / , xp , ** kwargs ) -> ndarray :
416
402
if xp .issubdtype (x .dtype , xp .integer ):
417
403
return x
418
- return xp .floor (x )
404
+ return xp .floor (x , ** kwargs )
419
405
420
- def trunc (x : ndarray , / , xp ) -> ndarray :
406
+ def trunc (x : ndarray , / , xp , ** kwargs ) -> ndarray :
421
407
if xp .issubdtype (x .dtype , xp .integer ):
422
408
return x
423
- return xp .trunc (x )
424
-
425
- __all__ = ['acos' , 'acosh' , 'asin' , 'asinh' , 'atan' , 'atan2' , 'atanh' ,
426
- 'bitwise_left_shift' , 'bitwise_invert' , 'bitwise_right_shift' ,
427
- 'concat' , 'pow' , 'UniqueAllResult' , 'UniqueCountsResult' ,
428
- 'UniqueInverseResult' , 'unique_all' , 'unique_counts' ,
429
- 'unique_inverse' , 'unique_values' , 'astype' , 'std' , 'var' ,
430
- 'permute_dims' , 'reshape' , 'argsort' , 'sort' , 'sum' , 'prod' ,
431
- 'ceil' , 'floor' , 'trunc' ]
409
+ return xp .trunc (x , ** kwargs )
410
+
411
+ __all__ = ['UniqueAllResult' , 'UniqueCountsResult' , 'UniqueInverseResult' ,
412
+ 'unique_all' , 'unique_counts' , 'unique_inverse' , 'unique_values' ,
413
+ 'astype' , 'std' , 'var' , 'permute_dims' , 'reshape' , 'argsort' ,
414
+ 'sort' , 'sum' , 'prod' , 'ceil' , 'floor' , 'trunc' ]
0 commit comments