Skip to content

Commit dfc1275

Browse files
committed
Pass kwargs through to torch.linalg.vecdot
1 parent 968e499 commit dfc1275

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

array_api_compat/torch/linalg.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,15 @@ def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
2222
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
2323
return torch_linalg.cross(x1, x2, dim=axis)
2424

25-
def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array:
25+
def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
2626
from ._aliases import isdtype
2727

2828
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
2929

3030
# torch.linalg.vecdot doesn't support integer dtypes
3131
if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'):
32+
if kwargs:
33+
raise RuntimeError("vecdot kwargs not supported for integral dtypes")
3234
ndim = max(x1.ndim, x2.ndim)
3335
x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
3436
x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
@@ -41,7 +43,7 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array:
4143

4244
res = x1_[..., None, :] @ x2_[..., None]
4345
return res[..., 0, 0]
44-
return torch.linalg.vecdot(x1, x2, axis=axis)
46+
return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs)
4547

4648
__all__ = linalg_all + ['outer', 'trace', 'matrix_transpose', 'tensordot', 'vecdot']
4749

0 commit comments

Comments
 (0)