diff --git a/spec/API_specification/linear_algebra_functions.md b/spec/API_specification/linear_algebra_functions.md index 209a2f787..781ae2806 100644 --- a/spec/API_specification/linear_algebra_functions.md +++ b/spec/API_specification/linear_algebra_functions.md @@ -279,6 +279,39 @@ TODO TODO +(function-tensordot)= +### tensordot(x1, x2, /, *, axes=2) + +Returns a tensor contraction of `x1` and `x2` over specific axes. + +#### Parameters + +- **x1**: _<array>_ + + - first input array. Should have a numeric data type. + +- **x2**: _<array>_ + + - second input array. Must be compatible with `x1` (see {ref}`broadcasting`). Should have a numeric data type. + +- **axes**: _Union\[ int, Tuple\[ Sequence\[ int ], Sequence\[ int ] ] ]_ + + - number of axes (dimensions) to contract or explicit sequences of axes (dimensions) for `x1` and `x2`, respectively. + + If `axes` is an `int` equal to `N`, then contraction must be performed over the last `N` axes of `x1` and the first `N` axes of `x2` in order. The size of each corresponding axis (dimension) must match. Must be nonnegative. + + - If `N` equals `0`, the result is the tensor (outer) product. + - If `N` equals `1`, the result is the tensor dot product. + - If `N` equals `2`, the result is the tensor double contraction (default). + + If `axes` is a tuple of two sequences `(x1_axes, x2_axes)`, the first sequence must apply to `x` and the second sequence to `x2`. Both sequences must have the same length. Each axis (dimension) `x1_axes[i]` for `x1` must have the same size as the respective axis (dimension) `x2_axes[i]` for `x2`. Each sequence must consist of unique (nonnegative) integers that specify valid axes for each respective array. + +#### Returns + +- **out**: _<array>_ + + - an array containing the tensor contraction whose shape consists of the non-contracted axes (dimensions) of the first array `x1`, followed by the non-contracted axes (dimensions) of the second array `x2`. The returned array must have a data type determined by {ref}`type-promotion`. + (function-trace)= ### trace(x, /, *, axis1=0, axis2=1, offset=0)