1
1
from __future__ import annotations
2
2
3
3
import math
4
- from typing import Literal , NamedTuple , Optional , Tuple , Union
4
+ from typing import Literal , NamedTuple , cast
5
5
6
6
import numpy as np
7
+
7
8
if np .__version__ [0 ] == "2" :
8
9
from numpy .lib .array_utils import normalize_axis_tuple
9
10
else :
10
11
from numpy .core .numeric import normalize_axis_tuple
11
12
12
- from ._aliases import matmul , matrix_transpose , tensordot , vecdot , isdtype
13
13
from .._internal import get_xp
14
- from ._typing import Array , Namespace
14
+ from ._aliases import isdtype , matmul , matrix_transpose , tensordot , vecdot
15
+ from ._typing import Array , DType , Namespace
16
+
15
17
16
18
# These are in the main NumPy namespace but not in numpy.linalg
17
- def cross (x1 : Array , x2 : Array , / , xp : Namespace , * , axis : int = - 1 , ** kwargs ) -> Array :
19
+ def cross (
20
+ x1 : Array ,
21
+ x2 : Array ,
22
+ / ,
23
+ xp : Namespace ,
24
+ * ,
25
+ axis : int = - 1 ,
26
+ ** kwargs : object ,
27
+ ) -> Array :
18
28
return xp .cross (x1 , x2 , axis = axis , ** kwargs )
19
29
20
- def outer (x1 : Array , x2 : Array , / , xp : Namespace , ** kwargs ) -> Array :
30
+ def outer (x1 : Array , x2 : Array , / , xp : Namespace , ** kwargs : object ) -> Array :
21
31
return xp .outer (x1 , x2 , ** kwargs )
22
32
23
33
class EighResult (NamedTuple ):
@@ -39,46 +49,66 @@ class SVDResult(NamedTuple):
39
49
40
50
# These functions are the same as their NumPy counterparts except they return
41
51
# a namedtuple.
42
- def eigh (x : Array , / , xp : Namespace , ** kwargs ) -> EighResult :
52
+ def eigh (x : Array , / , xp : Namespace , ** kwargs : object ) -> EighResult :
43
53
return EighResult (* xp .linalg .eigh (x , ** kwargs ))
44
54
45
- def qr (x : Array , / , xp : Namespace , * , mode : Literal ['reduced' , 'complete' ] = 'reduced' ,
46
- ** kwargs ) -> QRResult :
55
+ def qr (
56
+ x : Array ,
57
+ / ,
58
+ xp : Namespace ,
59
+ * ,
60
+ mode : Literal ["reduced" , "complete" ] = "reduced" ,
61
+ ** kwargs : object ,
62
+ ) -> QRResult :
47
63
return QRResult (* xp .linalg .qr (x , mode = mode , ** kwargs ))
48
64
49
- def slogdet (x : Array , / , xp : Namespace , ** kwargs ) -> SlogdetResult :
65
+ def slogdet (x : Array , / , xp : Namespace , ** kwargs : object ) -> SlogdetResult :
50
66
return SlogdetResult (* xp .linalg .slogdet (x , ** kwargs ))
51
67
52
68
def svd (
53
- x : Array , / , xp : Namespace , * , full_matrices : bool = True , ** kwargs
69
+ x : Array ,
70
+ / ,
71
+ xp : Namespace ,
72
+ * ,
73
+ full_matrices : bool = True ,
74
+ ** kwargs : object ,
54
75
) -> SVDResult :
55
76
return SVDResult (* xp .linalg .svd (x , full_matrices = full_matrices , ** kwargs ))
56
77
57
78
# These functions have additional keyword arguments
58
79
59
80
# The upper keyword argument is new from NumPy
60
- def cholesky (x : Array , / , xp : Namespace , * , upper : bool = False , ** kwargs ) -> Array :
81
+ def cholesky (
82
+ x : Array ,
83
+ / ,
84
+ xp : Namespace ,
85
+ * ,
86
+ upper : bool = False ,
87
+ ** kwargs : object ,
88
+ ) -> Array :
61
89
L = xp .linalg .cholesky (x , ** kwargs )
62
90
if upper :
63
91
U = get_xp (xp )(matrix_transpose )(L )
64
92
if get_xp (xp )(isdtype )(U .dtype , 'complex floating' ):
65
- U = xp .conj (U )
93
+ U = xp .conj (U ) # pyright: ignore[reportConstantRedefinition]
66
94
return U
67
95
return L
68
96
69
97
# The rtol keyword argument of matrix_rank() and pinv() is new from NumPy.
70
98
# Note that it has a different semantic meaning from tol and rcond.
71
- def matrix_rank (x : Array ,
72
- / ,
73
- xp : Namespace ,
74
- * ,
75
- rtol : Optional [Union [float , Array ]] = None ,
76
- ** kwargs ) -> Array :
99
+ def matrix_rank (
100
+ x : Array ,
101
+ / ,
102
+ xp : Namespace ,
103
+ * ,
104
+ rtol : float | Array | None = None ,
105
+ ** kwargs : object ,
106
+ ) -> Array :
77
107
# this is different from xp.linalg.matrix_rank, which supports 1
78
108
# dimensional arrays.
79
109
if x .ndim < 2 :
80
110
raise xp .linalg .LinAlgError ("1-dimensional array given. Array must be at least two-dimensional" )
81
- S = get_xp (xp )(svdvals )(x , ** kwargs )
111
+ S : Array = get_xp (xp )(svdvals )(x , ** kwargs )
82
112
if rtol is None :
83
113
tol = S .max (axis = - 1 , keepdims = True ) * max (x .shape [- 2 :]) * xp .finfo (S .dtype ).eps
84
114
else :
@@ -88,7 +118,12 @@ def matrix_rank(x: Array,
88
118
return xp .count_nonzero (S > tol , axis = - 1 )
89
119
90
120
def pinv (
91
- x : Array , / , xp : Namespace , * , rtol : Optional [Union [float , Array ]] = None , ** kwargs
121
+ x : Array ,
122
+ / ,
123
+ xp : Namespace ,
124
+ * ,
125
+ rtol : float | Array | None = None ,
126
+ ** kwargs : object ,
92
127
) -> Array :
93
128
# this is different from xp.linalg.pinv, which does not multiply the
94
129
# default tolerance by max(M, N).
@@ -104,23 +139,23 @@ def matrix_norm(
104
139
xp : Namespace ,
105
140
* ,
106
141
keepdims : bool = False ,
107
- ord : Optional [ Union [ int , float , Literal [' fro' , ' nuc' ]]] = ' fro' ,
142
+ ord : float | Literal [" fro" , " nuc" ] | None = " fro" ,
108
143
) -> Array :
109
144
return xp .linalg .norm (x , axis = (- 2 , - 1 ), keepdims = keepdims , ord = ord )
110
145
111
146
# svdvals is not in NumPy (but it is in SciPy). It is equivalent to
112
147
# xp.linalg.svd(compute_uv=False).
113
- def svdvals (x : Array , / , xp : Namespace ) -> Union [ Array , Tuple [Array , ...] ]:
148
+ def svdvals (x : Array , / , xp : Namespace ) -> Array | tuple [Array , ...]:
114
149
return xp .linalg .svd (x , compute_uv = False )
115
150
116
151
def vector_norm (
117
152
x : Array ,
118
153
/ ,
119
154
xp : Namespace ,
120
155
* ,
121
- axis : Optional [ Union [ int , Tuple [int , ...]]] = None ,
156
+ axis : int | tuple [int , ...] | None = None ,
122
157
keepdims : bool = False ,
123
- ord : Optional [ Union [ int , float ]] = 2 ,
158
+ ord : float = 2 ,
124
159
) -> Array :
125
160
# xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or
126
161
# when axis=None and the input is 2-D, so to force a vector norm, we make
@@ -133,7 +168,10 @@ def vector_norm(
133
168
elif isinstance (axis , tuple ):
134
169
# Note: The axis argument supports any number of axes, whereas
135
170
# xp.linalg.norm() only supports a single axis for vector norm.
136
- normalized_axis = normalize_axis_tuple (axis , x .ndim )
171
+ normalized_axis = cast (
172
+ "tuple[int, ...]" ,
173
+ normalize_axis_tuple (axis , x .ndim ), # pyright: ignore[reportCallIssue]
174
+ )
137
175
rest = tuple (i for i in range (x .ndim ) if i not in normalized_axis )
138
176
newshape = axis + rest
139
177
_x = xp .transpose (x , newshape ).reshape (
@@ -149,7 +187,13 @@ def vector_norm(
149
187
# We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks
150
188
# above to avoid matrix norm logic.
151
189
shape = list (x .shape )
152
- _axis = normalize_axis_tuple (range (x .ndim ) if axis is None else axis , x .ndim )
190
+ _axis = cast (
191
+ "tuple[int, ...]" ,
192
+ normalize_axis_tuple ( # pyright: ignore[reportCallIssue]
193
+ range (x .ndim ) if axis is None else axis ,
194
+ x .ndim ,
195
+ ),
196
+ )
153
197
for i in _axis :
154
198
shape [i ] = 1
155
199
res = xp .reshape (res , tuple (shape ))
@@ -159,11 +203,17 @@ def vector_norm(
159
203
# xp.diagonal and xp.trace operate on the first two axes whereas these
160
204
# operates on the last two
161
205
162
- def diagonal (x : Array , / , xp : Namespace , * , offset : int = 0 , ** kwargs ) -> Array :
206
+ def diagonal (x : Array , / , xp : Namespace , * , offset : int = 0 , ** kwargs : object ) -> Array :
163
207
return xp .diagonal (x , offset = offset , axis1 = - 2 , axis2 = - 1 , ** kwargs )
164
208
165
209
def trace (
166
- x : Array , / , xp : Namespace , * , offset : int = 0 , dtype = None , ** kwargs
210
+ x : Array ,
211
+ / ,
212
+ xp : Namespace ,
213
+ * ,
214
+ offset : int = 0 ,
215
+ dtype : DType | None = None ,
216
+ ** kwargs : object ,
167
217
) -> Array :
168
218
return xp .asarray (
169
219
xp .trace (x , offset = offset , dtype = dtype , axis1 = - 2 , axis2 = - 1 , ** kwargs )
0 commit comments