Skip to content

Commit 968e499

Browse files
committed
Add implementation for torch.linalg.vecdot for integer dtypes
1 parent b32a5b3 commit 968e499

File tree

1 file changed

+22
-1
lines changed

1 file changed

+22
-1
lines changed

array_api_compat/torch/linalg.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,27 @@ def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
2222
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
2323
return torch_linalg.cross(x1, x2, dim=axis)
2424

25-
__all__ = linalg_all + ['outer', 'trace', 'matrix_transpose', 'tensordot']
25+
def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array:
    """Compute the dot product of two arrays along ``axis``.

    Wrapper around ``torch.linalg.vecdot`` implementing the array API
    ``vecdot`` semantics. ``torch.linalg.vecdot`` does not support integer
    dtypes, so for integral inputs the product is computed manually via a
    broadcast + matmul of the two vectors.

    Parameters
    ----------
    x1, x2 : array
        Input arrays; they are broadcast against each other.
    axis : int
        Axis along which the dot product is taken (default ``-1``).

    Returns
    -------
    array
        The elementwise dot products, with ``axis`` removed.

    Raises
    ------
    ValueError
        If ``x1`` and ``x2`` do not have the same size along ``axis``
        (checked against the shapes *before* broadcasting, per the spec).
    """
    from ._aliases import isdtype

    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)

    # torch.linalg.vecdot doesn't support integer dtypes
    if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'):
        # Validate the axis sizes on the pre-broadcast shapes: broadcasting
        # could silently "fix" a size-1 mismatch that the spec forbids.
        ndim = max(x1.ndim, x2.ndim)
        x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
        x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
        if x1_shape[axis] != x2_shape[axis]:
            raise ValueError("x1 and x2 must have the same size along the given axis")

        x1_, x2_ = torch.broadcast_tensors(x1, x2)
        x1_ = torch.moveaxis(x1_, axis, -1)
        x2_ = torch.moveaxis(x2_, axis, -1)

        # (..., 1, n) @ (..., n, 1) -> (..., 1, 1); squeeze the two
        # singleton dimensions to drop the contracted axis.
        res = x1_[..., None, :] @ x2_[..., None]
        return res[..., 0, 0]
    # BUG FIX: torch.linalg.vecdot's keyword argument is `dim`, not `axis`
    # (passing axis=axis raised TypeError). Matches `cross` above, which
    # already passes dim=axis.
    return torch.linalg.vecdot(x1, x2, dim=axis)
46+
__all__ = linalg_all + ['outer', 'trace', 'matrix_transpose', 'tensordot', 'vecdot']
2647

2748
del linalg_all

0 commit comments

Comments (0)