From f9bd18520e792029ec777f3b525516d2006f82a1 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 21 Apr 2023 16:15:59 -0600 Subject: [PATCH 1/2] Allow negative axes in tensordot Both NumPy and PyTorch allow this, and there should be no ambiguity or difficulty in doing so, as long as the specified axes remain unique. --- src/array_api_stubs/_draft/linear_algebra_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_stubs/_draft/linear_algebra_functions.py b/src/array_api_stubs/_draft/linear_algebra_functions.py index d9ac2437c..c2d346c51 100644 --- a/src/array_api_stubs/_draft/linear_algebra_functions.py +++ b/src/array_api_stubs/_draft/linear_algebra_functions.py @@ -97,7 +97,7 @@ def tensordot( - If ``N`` equals ``1``, the result is the tensor dot product. - If ``N`` equals ``2``, the result is the tensor double contraction (default). - If ``axes`` is a tuple of two sequences ``(x1_axes, x2_axes)``, the first sequence must apply to ``x1`` and the second sequence to ``x2``. Both sequences must have the same length. Each axis (dimension) ``x1_axes[i]`` for ``x1`` must have the same size as the respective axis (dimension) ``x2_axes[i]`` for ``x2``. Each sequence must consist of unique (nonnegative) integers that specify valid axes for each respective array. + If ``axes`` is a tuple of two sequences ``(x1_axes, x2_axes)``, the first sequence must apply to ``x1`` and the second sequence to ``x2``. Both sequences must have the same length. Each axis (dimension) ``x1_axes[i]`` for ``x1`` must have the same size as the respective axis (dimension) ``x2_axes[i]`` for ``x2``. Each sequence must consist of integers that uniquely specify valid axes for each respective array. .. note:: From ba0b357b7e45fb03fc98cf9d65f77260a6977e27 Mon Sep 17 00:00:00 2001 From: Athan Date: Tue, 19 Sep 2023 17:19:25 -0700 Subject: [PATCH 2/2] Explicitly specify allowed value ranges This follows precedent elsewhere (e.g., `moveaxis`). 
--- src/array_api_stubs/_draft/linear_algebra_functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/array_api_stubs/_draft/linear_algebra_functions.py b/src/array_api_stubs/_draft/linear_algebra_functions.py index c2d346c51..64b46aa97 100644 --- a/src/array_api_stubs/_draft/linear_algebra_functions.py +++ b/src/array_api_stubs/_draft/linear_algebra_functions.py @@ -89,7 +89,7 @@ def tensordot( Contracted axes (dimensions) must not be broadcasted. axes: Union[int, Tuple[Sequence[int], Sequence[int]]] - number of axes (dimensions) to contract or explicit sequences of axes (dimensions) for ``x1`` and ``x2``, respectively. + number of axes (dimensions) to contract or explicit sequences of axis (dimension) indices for ``x1`` and ``x2``, respectively. If ``axes`` is an ``int`` equal to ``N``, then contraction must be performed over the last ``N`` axes of ``x1`` and the first ``N`` axes of ``x2`` in order. The size of each corresponding axis (dimension) must match. Must be nonnegative. @@ -97,7 +97,7 @@ def tensordot( - If ``N`` equals ``1``, the result is the tensor dot product. - If ``N`` equals ``2``, the result is the tensor double contraction (default). - If ``axes`` is a tuple of two sequences ``(x1_axes, x2_axes)``, the first sequence must apply to ``x1`` and the second sequence to ``x2``. Both sequences must have the same length. Each axis (dimension) ``x1_axes[i]`` for ``x1`` must have the same size as the respective axis (dimension) ``x2_axes[i]`` for ``x2``. Each sequence must consist of integers that uniquely specify valid axes for each respective array. + If ``axes`` is a tuple of two sequences ``(x1_axes, x2_axes)``, the first sequence must apply to ``x1`` and the second sequence to ``x2``. Both sequences must have the same length. Each axis (dimension) ``x1_axes[i]`` for ``x1`` must have the same size as the respective axis (dimension) ``x2_axes[i]`` for ``x2``. Each index referred to in a sequence must be unique. 
If ``x1`` has rank (i.e., number of dimensions) ``N``, a valid ``x1`` axis must reside on the half-open interval ``[-N, N)``. If ``x2`` has rank ``M``, a valid ``x2`` axis must reside on the half-open interval ``[-M, M)``. .. note::