From 7728c985242cc346e8d740df649f7c7c4452d8c6 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 1 Feb 2024 18:28:56 -0700 Subject: [PATCH] Only require axis to be negative in vecdot and cross Nonnegative axes and negative axes less than the smaller of the two arrays are unspecified. This is because it is ambiguous in these cases whether the dimension should refer to the axis before or after broadcasting. Preciously, the spec stated it should refer to the dimension before broadcasting, but this deviates from NumPy gufunc behavior, and results in ambiguous and confusing situations, where, for instance, the result of a the function is different when the inputs are manually broadcasted together. Also clean up some of the cross text a little bit since the computed dimension must be exactly size 3. Fixes #724 Fixes #617 See the discussion in those issues for more details. --- src/array_api_stubs/_draft/linalg.py | 9 ++++----- src/array_api_stubs/_draft/linear_algebra_functions.py | 2 +- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/array_api_stubs/_draft/linalg.py b/src/array_api_stubs/_draft/linalg.py index d05b53a9f..1e7efa95e 100644 --- a/src/array_api_stubs/_draft/linalg.py +++ b/src/array_api_stubs/_draft/linalg.py @@ -83,15 +83,15 @@ def cross(x1: array, x2: array, /, *, axis: int = -1) -> array: Parameters ---------- x1: array - first input array. Must have a numeric data type. + first input array. Must have a numeric data type. The size of the axis over which the cross product is to be computed must be equal to 3. x2: array - second input array. Must be compatible with ``x1`` for all non-compute axes (see :ref:`broadcasting`). The size of the axis over which to compute the cross product must be the same size as the respective axis in ``x1``. Must have a numeric data type. + second input array. Must be broadcast compatible with ``x1`` along all axes other than the axis along which the cross-product is computed (see :ref:`broadcasting`). The size of the axis over which the cross product is to be computed must be equal to 3. Must have a numeric data type. .. note:: The compute axis (dimension) must not be broadcasted. axis: int - the axis (dimension) of ``x1`` and ``x2`` containing the vectors for which to compute the cross product. Must be an integer on the interval ``[-N, N)``, where ``N`` is the rank (number of dimensions) of the shape determined according to :ref:`broadcasting`. If specified as a negative integer, the function must determine the axis along which to compute the cross product by counting backward from the last dimension (where ``-1`` refers to the last dimension). By default, the function must compute the cross product over the last axis. Default: ``-1``. + the axis (dimension) of ``x1`` and ``x2`` containing the vectors for which to compute the cross product. Should be an integer on the interval ``[-N, -1]``, where ``N`` is ``min(x1.ndim, x2.ndim)``. The function must determine the axis along which to compute the cross product by counting backward from the last dimension (where ``-1`` refers to the last dimension). By default, the function must compute the cross product over the last axis. Default: ``-1``. Returns ------- @@ -110,8 +110,7 @@ def cross(x1: array, x2: array, /, *, axis: int = -1) -> array: **Raises** - - if the size of the axis over which to compute the cross product is not equal to ``3``. - - if the size of the axis over which to compute the cross product is not the same (before broadcasting) for both ``x1`` and ``x2``. + - if the size of the axis over which to compute the cross product is not equal to ``3`` (before broadcasting) for both ``x1`` and ``x2``. """ diff --git a/src/array_api_stubs/_draft/linear_algebra_functions.py b/src/array_api_stubs/_draft/linear_algebra_functions.py index 96f082bd5..eea898a6b 100644 --- a/src/array_api_stubs/_draft/linear_algebra_functions.py +++ b/src/array_api_stubs/_draft/linear_algebra_functions.py @@ -141,7 +141,7 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array: The contracted axis (dimension) must not be broadcasted. axis: int - axis over which to compute the dot product. Must be an integer on the interval ``[-N, N)``, where ``N`` is the rank (number of dimensions) of the shape determined according to :ref:`broadcasting`. If specified as a negative integer, the function must determine the axis along which to compute the dot product by counting backward from the last dimension (where ``-1`` refers to the last dimension). By default, the function must compute the dot product over the last axis. Default: ``-1``. + the axis (dimension) of ``x1`` and ``x2`` containing the vectors for which to compute the dot product. Should be an integer on the interval ``[-N, -1]``, where ``N`` is ``min(x1.ndim, x2.ndim)``. The function must determine the axis along which to compute the dot product by counting backward from the last dimension (where ``-1`` refers to the last dimension). By default, the function must compute the dot product over the last axis. Default: ``-1``. Returns -------