Skip to content

Commit b933915

Browse files
authored
docs: clarify broadcasting semantics and output shape in linalg.solve (#810)
PR-URL: #810
1 parent b569b03 commit b933915

File tree

5 files changed

+8
-8
lines changed

5 files changed

+8
-8
lines changed

src/array_api_stubs/_2022_12/linalg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -594,12 +594,12 @@ def solve(x1: array, x2: array, /) -> array:
594594
x1: array
595595
coefficient array ``A`` having shape ``(..., M, M)`` and whose innermost two dimensions form square matrices. Must be of full rank (i.e., all rows or, equivalently, columns must be linearly independent). Should have a floating-point data type.
596596
x2: array
597-
ordinate (or "dependent variable") array ``B``. If ``x2`` has shape ``(M,)``, ``x2`` is equivalent to an array having shape ``(..., M, 1)``. If ``x2`` has shape ``(..., M, K)``, each column ``k`` defines a set of ordinate values for which to compute a solution, and ``shape(x2)[:-1]`` must be compatible with ``shape(x1)[:-1]`` (see :ref:`broadcasting`). Should have a floating-point data type.
597+
ordinate (or "dependent variable") array ``B``. If ``x2`` has shape ``(M,)``, ``x2`` is equivalent to an array having shape ``(..., M, 1)``. If ``x2`` has shape ``(..., M, K)``, each column ``k`` defines a set of ordinate values for which to compute a solution, and ``shape(x2)[:-2]`` must be compatible with ``shape(x1)[:-2]`` (see :ref:`broadcasting`). Should have a floating-point data type.
598598
599599
Returns
600600
-------
601601
out: array
602-
an array containing the solution to the system ``AX = B`` for each square matrix. The returned array must have the same shape as ``x2`` (i.e., the array corresponding to ``B``) and must have a floating-point data type determined by :ref:`type-promotion`.
602+
an array containing the solution to the system ``AX = B`` for each square matrix. If ``x2`` has shape ``(M,)``, the returned array must have shape equal to ``shape(x1)[:-2] + shape(x2)[-1:]``. Otherwise, if ``x2`` has shape ``(..., M, K)```, the returned array must have shape equal to ``(..., M, K)``, where ``...`` refers to the result of broadcasting ``shape(x1)[:-2]`` and ``shape(x2)[:-2]``. The returned array must have a floating-point data type determined by :ref:`type-promotion`.
603603
604604
Notes
605605
-----

src/array_api_stubs/_2023_12/linalg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -623,12 +623,12 @@ def solve(x1: array, x2: array, /) -> array:
623623
x1: array
624624
coefficient array ``A`` having shape ``(..., M, M)`` and whose innermost two dimensions form square matrices. Must be of full rank (i.e., all rows or, equivalently, columns must be linearly independent). Should have a floating-point data type.
625625
x2: array
626-
ordinate (or "dependent variable") array ``B``. If ``x2`` has shape ``(M,)``, ``x2`` is equivalent to an array having shape ``(..., M, 1)``. If ``x2`` has shape ``(..., M, K)``, each column ``k`` defines a set of ordinate values for which to compute a solution, and ``shape(x2)[:-1]`` must be compatible with ``shape(x1)[:-1]`` (see :ref:`broadcasting`). Should have a floating-point data type.
626+
ordinate (or "dependent variable") array ``B``. If ``x2`` has shape ``(M,)``, ``x2`` is equivalent to an array having shape ``(..., M, 1)``. If ``x2`` has shape ``(..., M, K)``, each column ``k`` defines a set of ordinate values for which to compute a solution, and ``shape(x2)[:-2]`` must be compatible with ``shape(x1)[:-2]`` (see :ref:`broadcasting`). Should have a floating-point data type.
627627
628628
Returns
629629
-------
630630
out: array
631-
an array containing the solution to the system ``AX = B`` for each square matrix. The returned array must have the same shape as ``x2`` (i.e., the array corresponding to ``B``) and must have a floating-point data type determined by :ref:`type-promotion`.
631+
an array containing the solution to the system ``AX = B`` for each square matrix. If ``x2`` has shape ``(M,)``, the returned array must have shape equal to ``shape(x1)[:-2] + shape(x2)[-1:]``. Otherwise, if ``x2`` has shape ``(..., M, K)```, the returned array must have shape equal to ``(..., M, K)``, where ``...`` refers to the result of broadcasting ``shape(x1)[:-2]`` and ``shape(x2)[:-2]``. The returned array must have a floating-point data type determined by :ref:`type-promotion`.
632632
633633
Notes
634634
-----

src/array_api_stubs/_2023_12/manipulation_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ def tile(x: array, repetitions: Tuple[int, ...], /) -> array:
347347

348348
def unstack(x: array, /, *, axis: int = 0) -> Tuple[array, ...]:
349349
"""
350-
Splits an array in a sequence of arrays along the given axis.
350+
Splits an array into a sequence of arrays along the given axis.
351351
352352
Parameters
353353
----------

src/array_api_stubs/_draft/linalg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -623,12 +623,12 @@ def solve(x1: array, x2: array, /) -> array:
623623
x1: array
624624
coefficient array ``A`` having shape ``(..., M, M)`` and whose innermost two dimensions form square matrices. Must be of full rank (i.e., all rows or, equivalently, columns must be linearly independent). Should have a floating-point data type.
625625
x2: array
626-
ordinate (or "dependent variable") array ``B``. If ``x2`` has shape ``(M,)``, ``x2`` is equivalent to an array having shape ``(..., M, 1)``. If ``x2`` has shape ``(..., M, K)``, each column ``k`` defines a set of ordinate values for which to compute a solution, and ``shape(x2)[:-1]`` must be compatible with ``shape(x1)[:-1]`` (see :ref:`broadcasting`). Should have a floating-point data type.
626+
ordinate (or "dependent variable") array ``B``. If ``x2`` has shape ``(M,)``, ``x2`` is equivalent to an array having shape ``(..., M, 1)``. If ``x2`` has shape ``(..., M, K)``, each column ``k`` defines a set of ordinate values for which to compute a solution, and ``shape(x2)[:-2]`` must be compatible with ``shape(x1)[:-2]`` (see :ref:`broadcasting`). Should have a floating-point data type.
627627
628628
Returns
629629
-------
630630
out: array
631-
an array containing the solution to the system ``AX = B`` for each square matrix. The returned array must have the same shape as ``x2`` (i.e., the array corresponding to ``B``) and must have a floating-point data type determined by :ref:`type-promotion`.
631+
an array containing the solution to the system ``AX = B`` for each square matrix. If ``x2`` has shape ``(M,)``, the returned array must have shape equal to ``shape(x1)[:-2] + shape(x2)[-1:]``. Otherwise, if ``x2`` has shape ``(..., M, K)```, the returned array must have shape equal to ``(..., M, K)``, where ``...`` refers to the result of broadcasting ``shape(x1)[:-2]`` and ``shape(x2)[:-2]``. The returned array must have a floating-point data type determined by :ref:`type-promotion`.
632632
633633
Notes
634634
-----

src/array_api_stubs/_draft/manipulation_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ def tile(x: array, repetitions: Tuple[int, ...], /) -> array:
347347

348348
def unstack(x: array, /, *, axis: int = 0) -> Tuple[array, ...]:
349349
"""
350-
Splits an array in a sequence of arrays along the given axis.
350+
Splits an array into a sequence of arrays along the given axis.
351351
352352
Parameters
353353
----------

0 commit comments

Comments
 (0)