Skip to content

Commit 02fae2c

Browse files
authored
Fix matmul broadcasting guidance (#328)
* Add commas * Clarify broadcasting rules
1 parent 4c041ea commit 02fae2c

File tree

2 files changed

+18
-18
lines changed

2 files changed

+18
-18
lines changed

spec/API_specification/array_object.md

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -870,23 +870,23 @@ The `matmul` function must implement the same semantics as the built-in `@` oper
870870

871871
- **self**: _<array>_
872872

873-
- array instance. Should have a numeric data type. Must have at least one dimension. If `self` is one-dimensional having shape `(M)` and `other` has more than one dimension, `self` must be promoted to a two-dimensional array by prepending `1` to its dimensions (i.e., must have shape `(1, M)`). After matrix multiplication, the prepended dimensions in the returned array must be removed. If `self` has more than one dimension (including after vector-to-matrix promotion), `self` must be compatible with `other` (see {ref}`broadcasting`). If `self` has shape `(..., M, K)`, the innermost two dimensions form matrices on which to perform matrix multiplication.
873+
- array instance. Should have a numeric data type. Must have at least one dimension. If `self` is one-dimensional having shape `(M,)` and `other` has more than one dimension, `self` must be promoted to a two-dimensional array by prepending `1` to its dimensions (i.e., must have shape `(1, M)`). After matrix multiplication, the prepended dimensions in the returned array must be removed. If `self` has more than one dimension (including after vector-to-matrix promotion), `shape(self)[:-2]` must be compatible with `shape(other)[:-2]` (after vector-to-matrix promotion) (see {ref}`broadcasting`). If `self` has shape `(..., M, K)`, the innermost two dimensions form matrices on which to perform matrix multiplication.
874874

875875
- **other**: _<array>_
876876

877-
- other array. Should have a numeric data type. Must have at least one dimension. If `other` is one-dimensional having shape `(N)` and `self` has more than one dimension, `other` must be promoted to a two-dimensional array by appending `1` to its dimensions (i.e., must have shape `(N, 1)`). After matrix multiplication, the appended dimensions in the returned array must be removed. If `other` has more than one dimension (including after vector-to-matrix promotion), `other` must be compatible with `self` (see {ref}`broadcasting`). If `other` has shape `(..., K, N)`, the innermost two dimensions form matrices on which to perform matrix multiplication.
877+
- other array. Should have a numeric data type. Must have at least one dimension. If `other` is one-dimensional having shape `(N,)` and `self` has more than one dimension, `other` must be promoted to a two-dimensional array by appending `1` to its dimensions (i.e., must have shape `(N, 1)`). After matrix multiplication, the appended dimensions in the returned array must be removed. If `other` has more than one dimension (including after vector-to-matrix promotion), `shape(other)[:-2]` must be compatible with `shape(self)[:-2]` (after vector-to-matrix promotion) (see {ref}`broadcasting`). If `other` has shape `(..., K, N)`, the innermost two dimensions form matrices on which to perform matrix multiplication.
878878

879879
#### Returns
880880

881881
- **out**: _<array>_
882882

883-
- if both `self` and `other` are one-dimensional arrays having shape `(N)`, a zero-dimensional array containing the inner product as its only element.
883+
- if both `self` and `other` are one-dimensional arrays having shape `(N,)`, a zero-dimensional array containing the inner product as its only element.
884884
- if `self` is a two-dimensional array having shape `(M, K)` and `other` is a two-dimensional array having shape `(K, N)`, a two-dimensional array containing the [conventional matrix product](https://en.wikipedia.org/wiki/Matrix_multiplication) and having shape `(M, N)`.
885-
- if `self` is a one-dimensional array having shape `(K)` and `other` is an array having shape `(..., K, N)`, an array having shape `(..., N)` (i.e., prepended dimensions during vector-to-matrix promotion must be removed) and containing the [conventional matrix product](https://en.wikipedia.org/wiki/Matrix_multiplication).
886-
- if `self` is an array having shape `(..., M, K)` and `other` is a one-dimensional array having shape `(K)`, an array having shape `(..., M)` (i.e., appended dimensions during vector-to-matrix promotion must be removed) and containing the [conventional matrix product](https://en.wikipedia.org/wiki/Matrix_multiplication).
885+
- if `self` is a one-dimensional array having shape `(K,)` and `other` is an array having shape `(..., K, N)`, an array having shape `(..., N)` (i.e., prepended dimensions during vector-to-matrix promotion must be removed) and containing the [conventional matrix product](https://en.wikipedia.org/wiki/Matrix_multiplication).
886+
- if `self` is an array having shape `(..., M, K)` and `other` is a one-dimensional array having shape `(K,)`, an array having shape `(..., M)` (i.e., appended dimensions during vector-to-matrix promotion must be removed) and containing the [conventional matrix product](https://en.wikipedia.org/wiki/Matrix_multiplication).
887887
- if `self` is a two-dimensional array having shape `(M, K)` and `other` is an array having shape `(..., K, N)`, an array having shape `(..., M, N)` and containing the [conventional matrix product](https://en.wikipedia.org/wiki/Matrix_multiplication) for each stacked matrix.
888888
- if `self` is an array having shape `(..., M, K)` and `other` is a two-dimensional array having shape `(K, N)`, an array having shape `(..., M, N)` and containing the [conventional matrix product](https://en.wikipedia.org/wiki/Matrix_multiplication) for each stacked matrix.
889-
- if either `self` or `other` has more than two dimensions, an array having a shape determined by {ref}`broadcasting` `self` against `other` and containing the [conventional matrix product](https://en.wikipedia.org/wiki/Matrix_multiplication) for each stacked matrix.
889+
- if either `self` or `other` has more than two dimensions, an array having a shape determined by {ref}`broadcasting` `shape(self)[:-2]` against `shape(other)[:-2]` and containing the [conventional matrix product](https://en.wikipedia.org/wiki/Matrix_multiplication) for each stacked matrix.
890890

891891
The returned array must have a data type determined by {ref}`type-promotion`.
892892

@@ -897,9 +897,9 @@ The `matmul` function must implement the same semantics as the built-in `@` oper
897897
#### Raises
898898
899899
- if either `self` or `other` is a zero-dimensional array.
900-
- if `self` is a one-dimensional array having shape `(K)`, `other` is a one-dimensional array having shape `(L)`, and `K != L`.
901-
- if `self` is a one-dimensional array having shape `(K)`, `other` is an array having shape `(..., L, N)`, and `K != L`.
902-
- if `self` is an array having shape `(..., M, K)`, `other` is a one-dimensional array having shape `(L)`, and `K != L`.
900+
- if `self` is a one-dimensional array having shape `(K,)`, `other` is a one-dimensional array having shape `(L,)`, and `K != L`.
901+
- if `self` is a one-dimensional array having shape `(K,)`, `other` is an array having shape `(..., L, N)`, and `K != L`.
902+
- if `self` is an array having shape `(..., M, K)`, `other` is a one-dimensional array having shape `(L,)`, and `K != L`.
903903
- if `self` is an array having shape `(..., M, K)`, `other` is an array having shape `(..., L, N)`, and `K != L`.
904904
905905
(method-__mod__)=

spec/API_specification/linear_algebra_functions.md

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,32 +29,32 @@ The `matmul` function must implement the same semantics as the built-in `@` oper
2929

3030
- **x1**: _<array>_
3131

32-
- first input array. Should have a numeric data type. Must have at least one dimension. If `x1` is one-dimensional having shape `(M)` and `x2` has more than one dimension, `x1` must be promoted to a two-dimensional array by prepending `1` to its dimensions (i.e., must have shape `(1, M)`). After matrix multiplication, the prepended dimensions in the returned array must be removed. If `x1` has more than one dimension (including after vector-to-matrix promotion), `x1` must be compatible with `x2` (see {ref}`broadcasting`). If `x1` has shape `(..., M, K)`, the innermost two dimensions form matrices on which to perform matrix multiplication.
32+
- first input array. Should have a numeric data type. Must have at least one dimension. If `x1` is one-dimensional having shape `(M,)` and `x2` has more than one dimension, `x1` must be promoted to a two-dimensional array by prepending `1` to its dimensions (i.e., must have shape `(1, M)`). After matrix multiplication, the prepended dimensions in the returned array must be removed. If `x1` has more than one dimension (including after vector-to-matrix promotion), `shape(x1)[:-2]` must be compatible with `shape(x2)[:-2]` (after vector-to-matrix promotion) (see {ref}`broadcasting`). If `x1` has shape `(..., M, K)`, the innermost two dimensions form matrices on which to perform matrix multiplication.
3333

3434
- **x2**: _<array>_
3535

36-
- second input array. Should have a numeric data type. Must have at least one dimension. If `x2` is one-dimensional having shape `(N)` and `x1` has more than one dimension, `x2` must be promoted to a two-dimensional array by appending `1` to its dimensions (i.e., must have shape `(N, 1)`). After matrix multiplication, the appended dimensions in the returned array must be removed. If `x2` has more than one dimension (including after vector-to-matrix promotion), `x2` must be compatible with `x1` (see {ref}`broadcasting`). If `x2` has shape `(..., K, N)`, the innermost two dimensions form matrices on which to perform matrix multiplication.
36+
- second input array. Should have a numeric data type. Must have at least one dimension. If `x2` is one-dimensional having shape `(N,)` and `x1` has more than one dimension, `x2` must be promoted to a two-dimensional array by appending `1` to its dimensions (i.e., must have shape `(N, 1)`). After matrix multiplication, the appended dimensions in the returned array must be removed. If `x2` has more than one dimension (including after vector-to-matrix promotion), `shape(x2)[:-2]` must be compatible with `shape(x1)[:-2]` (after vector-to-matrix promotion) (see {ref}`broadcasting`). If `x2` has shape `(..., K, N)`, the innermost two dimensions form matrices on which to perform matrix multiplication.
3737

3838
#### Returns
3939

4040
- **out**: _<array>_
4141

42-
- if both `x1` and `x2` are one-dimensional arrays having shape `(N)`, a zero-dimensional array containing the inner product as its only element.
42+
- if both `x1` and `x2` are one-dimensional arrays having shape `(N,)`, a zero-dimensional array containing the inner product as its only element.
4343
- if `x1` is a two-dimensional array having shape `(M, K)` and `x2` is a two-dimensional array having shape `(K, N)`, a two-dimensional array containing the [conventional matrix product](https://en.wikipedia.org/wiki/Matrix_multiplication) and having shape `(M, N)`.
44-
- if `x1` is a one-dimensional array having shape `(K)` and `x2` is an array having shape `(..., K, N)`, an array having shape `(..., N)` (i.e., prepended dimensions during vector-to-matrix promotion must be removed) and containing the [conventional matrix product](https://en.wikipedia.org/wiki/Matrix_multiplication).
45-
- if `x1` is an array having shape `(..., M, K)` and `x2` is a one-dimensional array having shape `(K)`, an array having shape `(..., M)` (i.e., appended dimensions during vector-to-matrix promotion must be removed) and containing the [conventional matrix product](https://en.wikipedia.org/wiki/Matrix_multiplication).
44+
- if `x1` is a one-dimensional array having shape `(K,)` and `x2` is an array having shape `(..., K, N)`, an array having shape `(..., N)` (i.e., prepended dimensions during vector-to-matrix promotion must be removed) and containing the [conventional matrix product](https://en.wikipedia.org/wiki/Matrix_multiplication).
45+
- if `x1` is an array having shape `(..., M, K)` and `x2` is a one-dimensional array having shape `(K,)`, an array having shape `(..., M)` (i.e., appended dimensions during vector-to-matrix promotion must be removed) and containing the [conventional matrix product](https://en.wikipedia.org/wiki/Matrix_multiplication).
4646
- if `x1` is a two-dimensional array having shape `(M, K)` and `x2` is an array having shape `(..., K, N)`, an array having shape `(..., M, N)` and containing the [conventional matrix product](https://en.wikipedia.org/wiki/Matrix_multiplication) for each stacked matrix.
4747
- if `x1` is an array having shape `(..., M, K)` and `x2` is a two-dimensional array having shape `(K, N)`, an array having shape `(..., M, N)` and containing the [conventional matrix product](https://en.wikipedia.org/wiki/Matrix_multiplication) for each stacked matrix.
48-
- if either `x1` or `x2` has more than two dimensions, an array having a shape determined by {ref}`broadcasting` `x1` against `x2` and containing the [conventional matrix product](https://en.wikipedia.org/wiki/Matrix_multiplication) for each stacked matrix.
48+
- if either `x1` or `x2` has more than two dimensions, an array having a shape determined by {ref}`broadcasting` `shape(x1)[:-2]` against `shape(x2)[:-2]` and containing the [conventional matrix product](https://en.wikipedia.org/wiki/Matrix_multiplication) for each stacked matrix.
4949

5050
The returned array must have a data type determined by {ref}`type-promotion`.
5151

5252
#### Raises
5353

5454
- if either `x1` or `x2` is a zero-dimensional array.
55-
- if `x1` is a one-dimensional array having shape `(K)`, `x2` is a one-dimensional array having shape `(L)`, and `K != L`.
56-
- if `x1` is a one-dimensional array having shape `(K)`, `x2` is an array having shape `(..., L, N)`, and `K != L`.
57-
- if `x1` is an array having shape `(..., M, K)`, `x2` is a one-dimensional array having shape `(L)`, and `K != L`.
55+
- if `x1` is a one-dimensional array having shape `(K,)`, `x2` is a one-dimensional array having shape `(L,)`, and `K != L`.
56+
- if `x1` is a one-dimensional array having shape `(K,)`, `x2` is an array having shape `(..., L, N)`, and `K != L`.
57+
- if `x1` is an array having shape `(..., M, K)`, `x2` is a one-dimensional array having shape `(L,)`, and `K != L`.
5858
- if `x1` is an array having shape `(..., M, K)`, `x2` is an array having shape `(..., L, N)`, and `K != L`.
5959

6060
(function-matrix-transpose)=

0 commit comments

Comments
 (0)