Skip to content

Commit d6bf189

Browse files
committed
Add wrapper for torch.linalg.solve that does correct type promotion
1 parent dfc1275 commit d6bf189

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

array_api_compat/torch/linalg.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
4545
return res[..., 0, 0]
4646
return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs)
4747

48-
__all__ = linalg_all + ['outer', 'trace', 'matrix_transpose', 'tensordot', 'vecdot']
48+
def solve(x1: array, x2: array, /, **kwargs) -> array:
49+
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
50+
return torch.linalg.solve(x1, x2, **kwargs)
51+
52+
__all__ = linalg_all + ['outer', 'trace', 'matrix_transpose', 'tensordot',
53+
'vecdot', 'solve']
4954

5055
del linalg_all

0 commit comments

Comments
 (0)