Skip to content

Commit cbec5f3

Browse files
committed
TYP: fix typing errors in common._linalg
1 parent 344ac1e commit cbec5f3

File tree

1 file changed

+78
-28
lines changed

1 file changed

+78
-28
lines changed

array_api_compat/common/_linalg.py

Lines changed: 78 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,33 @@
11
from __future__ import annotations
22

33
import math
4-
from typing import Literal, NamedTuple, Optional, Tuple, Union
4+
from typing import Literal, NamedTuple, cast
55

66
import numpy as np
7+
78
if np.__version__[0] == "2":
89
from numpy.lib.array_utils import normalize_axis_tuple
910
else:
1011
from numpy.core.numeric import normalize_axis_tuple
1112

12-
from ._aliases import matmul, matrix_transpose, tensordot, vecdot, isdtype
1313
from .._internal import get_xp
14-
from ._typing import Array, Namespace
14+
from ._aliases import isdtype, matmul, matrix_transpose, tensordot, vecdot
15+
from ._typing import Array, DType, Namespace
16+
1517

1618
# These are in the main NumPy namespace but not in numpy.linalg
17-
def cross(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1, **kwargs) -> Array:
19+
def cross(
20+
x1: Array,
21+
x2: Array,
22+
/,
23+
xp: Namespace,
24+
*,
25+
axis: int = -1,
26+
**kwargs: object,
27+
) -> Array:
1828
return xp.cross(x1, x2, axis=axis, **kwargs)
1929

20-
def outer(x1: Array, x2: Array, /, xp: Namespace, **kwargs) -> Array:
30+
def outer(x1: Array, x2: Array, /, xp: Namespace, **kwargs: object) -> Array:
2131
return xp.outer(x1, x2, **kwargs)
2232

2333
class EighResult(NamedTuple):
@@ -39,46 +49,66 @@ class SVDResult(NamedTuple):
3949

4050
# These functions are the same as their NumPy counterparts except they return
4151
# a namedtuple.
42-
def eigh(x: Array, /, xp: Namespace, **kwargs) -> EighResult:
52+
def eigh(x: Array, /, xp: Namespace, **kwargs: object) -> EighResult:
4353
return EighResult(*xp.linalg.eigh(x, **kwargs))
4454

45-
def qr(x: Array, /, xp: Namespace, *, mode: Literal['reduced', 'complete'] = 'reduced',
46-
**kwargs) -> QRResult:
55+
def qr(
56+
x: Array,
57+
/,
58+
xp: Namespace,
59+
*,
60+
mode: Literal["reduced", "complete"] = "reduced",
61+
**kwargs: object,
62+
) -> QRResult:
4763
return QRResult(*xp.linalg.qr(x, mode=mode, **kwargs))
4864

49-
def slogdet(x: Array, /, xp: Namespace, **kwargs) -> SlogdetResult:
65+
def slogdet(x: Array, /, xp: Namespace, **kwargs: object) -> SlogdetResult:
5066
return SlogdetResult(*xp.linalg.slogdet(x, **kwargs))
5167

5268
def svd(
53-
x: Array, /, xp: Namespace, *, full_matrices: bool = True, **kwargs
69+
x: Array,
70+
/,
71+
xp: Namespace,
72+
*,
73+
full_matrices: bool = True,
74+
**kwargs: object,
5475
) -> SVDResult:
5576
return SVDResult(*xp.linalg.svd(x, full_matrices=full_matrices, **kwargs))
5677

5778
# These functions have additional keyword arguments
5879

5980
# The upper keyword argument is new from NumPy
60-
def cholesky(x: Array, /, xp: Namespace, *, upper: bool = False, **kwargs) -> Array:
81+
def cholesky(
82+
x: Array,
83+
/,
84+
xp: Namespace,
85+
*,
86+
upper: bool = False,
87+
**kwargs: object,
88+
) -> Array:
6189
L = xp.linalg.cholesky(x, **kwargs)
6290
if upper:
6391
U = get_xp(xp)(matrix_transpose)(L)
6492
if get_xp(xp)(isdtype)(U.dtype, 'complex floating'):
65-
U = xp.conj(U)
93+
U = xp.conj(U) # pyright: ignore[reportConstantRedefinition]
6694
return U
6795
return L
6896

6997
# The rtol keyword argument of matrix_rank() and pinv() is new from NumPy.
7098
# Note that it has a different semantic meaning from tol and rcond.
71-
def matrix_rank(x: Array,
72-
/,
73-
xp: Namespace,
74-
*,
75-
rtol: Optional[Union[float, Array]] = None,
76-
**kwargs) -> Array:
99+
def matrix_rank(
100+
x: Array,
101+
/,
102+
xp: Namespace,
103+
*,
104+
rtol: float | Array | None = None,
105+
**kwargs: object,
106+
) -> Array:
77107
# this is different from xp.linalg.matrix_rank, which supports 1
78108
# dimensional arrays.
79109
if x.ndim < 2:
80110
raise xp.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional")
81-
S = get_xp(xp)(svdvals)(x, **kwargs)
111+
S: Array = get_xp(xp)(svdvals)(x, **kwargs)
82112
if rtol is None:
83113
tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * xp.finfo(S.dtype).eps
84114
else:
@@ -88,7 +118,12 @@ def matrix_rank(x: Array,
88118
return xp.count_nonzero(S > tol, axis=-1)
89119

90120
def pinv(
91-
x: Array, /, xp: Namespace, *, rtol: Optional[Union[float, Array]] = None, **kwargs
121+
x: Array,
122+
/,
123+
xp: Namespace,
124+
*,
125+
rtol: float | Array | None = None,
126+
**kwargs: object,
92127
) -> Array:
93128
# this is different from xp.linalg.pinv, which does not multiply the
94129
# default tolerance by max(M, N).
@@ -104,23 +139,23 @@ def matrix_norm(
104139
xp: Namespace,
105140
*,
106141
keepdims: bool = False,
107-
ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro',
142+
ord: float | Literal["fro", "nuc"] | None = "fro",
108143
) -> Array:
109144
return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord)
110145

111146
# svdvals is not in NumPy (but it is in SciPy). It is equivalent to
112147
# xp.linalg.svd(compute_uv=False).
113-
def svdvals(x: Array, /, xp: Namespace) -> Union[Array, Tuple[Array, ...]]:
148+
def svdvals(x: Array, /, xp: Namespace) -> Array | tuple[Array, ...]:
114149
return xp.linalg.svd(x, compute_uv=False)
115150

116151
def vector_norm(
117152
x: Array,
118153
/,
119154
xp: Namespace,
120155
*,
121-
axis: Optional[Union[int, Tuple[int, ...]]] = None,
156+
axis: int | tuple[int, ...] | None = None,
122157
keepdims: bool = False,
123-
ord: Optional[Union[int, float]] = 2,
158+
ord: float = 2,
124159
) -> Array:
125160
# xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or
126161
# when axis=None and the input is 2-D, so to force a vector norm, we make
@@ -133,7 +168,10 @@ def vector_norm(
133168
elif isinstance(axis, tuple):
134169
# Note: The axis argument supports any number of axes, whereas
135170
# xp.linalg.norm() only supports a single axis for vector norm.
136-
normalized_axis = normalize_axis_tuple(axis, x.ndim)
171+
normalized_axis = cast(
172+
"tuple[int, ...]",
173+
normalize_axis_tuple(axis, x.ndim), # pyright: ignore[reportCallIssue]
174+
)
137175
rest = tuple(i for i in range(x.ndim) if i not in normalized_axis)
138176
newshape = axis + rest
139177
_x = xp.transpose(x, newshape).reshape(
@@ -149,7 +187,13 @@ def vector_norm(
149187
# We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks
150188
# above to avoid matrix norm logic.
151189
shape = list(x.shape)
152-
_axis = normalize_axis_tuple(range(x.ndim) if axis is None else axis, x.ndim)
190+
_axis = cast(
191+
"tuple[int, ...]",
192+
normalize_axis_tuple( # pyright: ignore[reportCallIssue]
193+
range(x.ndim) if axis is None else axis,
194+
x.ndim,
195+
),
196+
)
153197
for i in _axis:
154198
shape[i] = 1
155199
res = xp.reshape(res, tuple(shape))
@@ -159,11 +203,17 @@ def vector_norm(
159203
# xp.diagonal and xp.trace operate on the first two axes whereas these
160204
# operates on the last two
161205

162-
def diagonal(x: Array, /, xp: Namespace, *, offset: int = 0, **kwargs) -> Array:
206+
def diagonal(x: Array, /, xp: Namespace, *, offset: int = 0, **kwargs: object) -> Array:
163207
return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs)
164208

165209
def trace(
166-
x: Array, /, xp: Namespace, *, offset: int = 0, dtype=None, **kwargs
210+
x: Array,
211+
/,
212+
xp: Namespace,
213+
*,
214+
offset: int = 0,
215+
dtype: DType | None = None,
216+
**kwargs: object,
167217
) -> Array:
168218
return xp.asarray(
169219
xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs)

0 commit comments

Comments
 (0)