12
12
if TYPE_CHECKING :
13
13
from typing import Optional , Tuple , Union
14
14
15
- from ...common ._typing import Device , Dtype , ndarray
15
+ from ...common ._typing import Device , Dtype , Array
16
16
17
17
import dask .array as da
18
18
@@ -37,7 +37,7 @@ def dask_arange(
37
37
dtype : Optional [Dtype ] = None ,
38
38
device : Optional [Device ] = None ,
39
39
** kwargs ,
40
- ) -> ndarray :
40
+ ) -> Array :
41
41
_check_device (xp , device )
42
42
args = [start ]
43
43
if stop is not None :
@@ -99,8 +99,18 @@ def dask_arange(
99
99
matrix_rank = get_xp (da )(_linalg .matrix_rank )
100
100
matrix_norm = get_xp (da )(_linalg .matrix_norm )
101
101
102
+ # Wrap the svd functions to not pass full_matrices to dask
103
+ # when full_matrices=False (as that is the defualt behavior for dask),
104
+ # and dask doesn't have the full_matrices keyword
105
+ _svd = get_xp (da )(_linalg .svd )
102
106
103
- def svdvals (x : ndarray ) -> Union [ndarray , Tuple [ndarray , ...]]:
107
+ def svd (x : Array , full_matrices : bool = True , ** kwargs ) -> SVDResult :
108
+ if full_matrices :
109
+ return _svd (x , full_matrices = full_matrices , ** kwargs )
110
+ return _svd (x , ** kwargs )
111
+
112
+
113
+ def svdvals (x : Array ) -> Array :
104
114
# TODO: can't avoid computing U or V for dask
105
115
_ , s , _ = da .linalg .svd (x )
106
116
return s
0 commit comments