Skip to content

Commit 924d297

Browse files
committed
wrap svd
1 parent 3f06837 commit 924d297

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

array_api_compat/dask/array/_aliases.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
if TYPE_CHECKING:
1313
from typing import Optional, Tuple, Union
1414

15-
from ...common._typing import Device, Dtype, ndarray
15+
from ...common._typing import Device, Dtype, Array
1616

1717
import dask.array as da
1818

@@ -37,7 +37,7 @@ def dask_arange(
3737
dtype: Optional[Dtype] = None,
3838
device: Optional[Device] = None,
3939
**kwargs,
40-
) -> ndarray:
40+
) -> Array:
4141
_check_device(xp, device)
4242
args = [start]
4343
if stop is not None:
@@ -99,8 +99,18 @@ def dask_arange(
9999
matrix_rank = get_xp(da)(_linalg.matrix_rank)
100100
matrix_norm = get_xp(da)(_linalg.matrix_norm)
101101

102+
# Wrap the svd functions to not pass full_matrices to dask
103+
# when full_matrices=False (as that is the defualt behavior for dask),
104+
# and dask doesn't have the full_matrices keyword
105+
_svd = get_xp(da)(_linalg.svd)
102106

103-
def svdvals(x: ndarray) -> Union[ndarray, Tuple[ndarray, ...]]:
107+
def svd(x: Array, full_matrices: bool = True, **kwargs) -> SVDResult:
108+
if full_matrices:
109+
return _svd(x, full_matrices=full_matrices, **kwargs)
110+
return _svd(x, **kwargs)
111+
112+
113+
def svdvals(x: Array) -> Array:
104114
# TODO: can't avoid computing U or V for dask
105115
_, s, _ = da.linalg.svd(x)
106116
return s

dask-xfails.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,6 @@ array_api_tests/test_linalg.py::test_solve
112112
# missing full_matrics kw
113113
# https://github.com/dask/dask/issues/10389
114114
# also only supports 2-d inputs
115-
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.svd]
116115
array_api_tests/test_linalg.py::test_svd
117116

118117
# Missing dlpack stuff

0 commit comments

Comments
 (0)