Skip to content

Commit cd381a0

Browse files
committed
address more feedback
1 parent 4be5517 commit cd381a0

File tree

2 files changed

+16
-6
lines changed

2 files changed

+16
-6
lines changed

array_api_compat/common/_linalg.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,7 @@ def matrix_rank(x: ndarray,
7777
# dimensional arrays.
7878
if x.ndim < 2:
7979
raise xp.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional")
80-
if hasattr(xp.linalg, "svdvals"):
81-
S = xp.linalg.svdvals(x, **kwargs)
82-
else:
83-
S = xp.linalg.svd(x, compute_uv=False, **kwargs)
80+
S = get_xp(xp).linalg.svdvals(x, **kwargs)
8481
if rtol is None:
8582
tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * xp.finfo(S.dtype).eps
8683
else:

array_api_compat/dask/array/linalg.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,17 @@
1010

1111
from typing import TYPE_CHECKING
1212
if TYPE_CHECKING:
13-
from typing import Optional, Union, Tuple
14-
from ...common._typing import ndarray, Device, Dtype
13+
from typing import Union, Tuple
14+
from ...common._typing import ndarray
15+
16+
# cupy.linalg doesn't have __all__. If it is added, replace this with
17+
#
18+
# from cupy.linalg import __all__ as linalg_all
19+
_n = {}
20+
exec('from dask.array.linalg import *', _n)
21+
del _n['__builtins__']
22+
linalg_all = list(_n)
23+
del _n
1524

1625
EighResult = _linalg.EighResult
1726
QRResult = _linalg.QRResult
@@ -30,6 +39,10 @@ def svdvals(x: ndarray) -> Union[ndarray, Tuple[ndarray, ...]]:
3039
vector_norm = get_xp(da)(_linalg.vector_norm)
3140
diagonal = get_xp(da)(_linalg.diagonal)
3241

42+
__all__ = linalg_all + ["EighResult", "QRResult", "SlogdetResult",
43+
"SVDResult", "qr", "cholesky", "matrix_rank", "matrix_norm",
44+
"svdvals", "vector_norm", "diagonal"]
45+
3346
del get_xp
3447
del da
3548
del _linalg

0 commit comments

Comments
 (0)