@@ -22,13 +22,15 @@ def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
22
22
x1 , x2 = _fix_promotion (x1 , x2 , only_scalar = False )
23
23
return torch_linalg .cross (x1 , x2 , dim = axis )
24
24
25
- def vecdot (x1 : array , x2 : array , / , * , axis : int = - 1 ) -> array :
25
+ def vecdot (x1 : array , x2 : array , / , * , axis : int = - 1 , ** kwargs ) -> array :
26
26
from ._aliases import isdtype
27
27
28
28
x1 , x2 = _fix_promotion (x1 , x2 , only_scalar = False )
29
29
30
30
# torch.linalg.vecdot doesn't support integer dtypes
31
31
if isdtype (x1 .dtype , 'integral' ) or isdtype (x2 .dtype , 'integral' ):
32
+ if kwargs :
33
+ raise RuntimeError ("vecdot kwargs not supported for integral dtypes" )
32
34
ndim = max (x1 .ndim , x2 .ndim )
33
35
x1_shape = (1 ,)* (ndim - x1 .ndim ) + tuple (x1 .shape )
34
36
x2_shape = (1 ,)* (ndim - x2 .ndim ) + tuple (x2 .shape )
@@ -41,7 +43,7 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array:
41
43
42
44
res = x1_ [..., None , :] @ x2_ [..., None ]
43
45
return res [..., 0 , 0 ]
44
- return torch .linalg .vecdot (x1 , x2 , axis = axis )
46
+ return torch .linalg .vecdot (x1 , x2 , dim = axis , ** kwargs )
45
47
46
48
__all__ = linalg_all + ['outer' , 'trace' , 'matrix_transpose' , 'tensordot' , 'vecdot' ]
47
49
0 commit comments