You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Update dtype requirements for existing linear algebra APIs (#148)
* Add dtype requirements
* Loosen dtype requirements
* Update copy
* Restrict `det` to floating-point data types
For more than 3-dimensions, efficiently computing the determinant
relies on factorization, which is likely to be inexact and require
floating-point computation. Accordingly, we should restrict input
dtypes to avoid potential casting issues (e.g., int64 to float64).
Copy file name to clipboardExpand all lines: spec/API_specification/linear_algebra_functions.md
+14-14Lines changed: 14 additions & 14 deletions
Original file line number
Diff line number
Diff line change
@@ -29,11 +29,11 @@ Returns the cross product of 3-element vectors. If `x1` and `x2` are multi-dimen
29
29
30
30
-**x1**: _<array>_
31
31
32
-
- first input array. Must have a data type of either `float32` or `float64`.
32
+
- first input array. Should have a numeric data type.
33
33
34
34
-**x2**: _<array>_
35
35
36
-
- second input array. Must have the same shape as `x1`. Must have a data type of either `float32` or `float64`.
36
+
- second input array. Must have the same shape as `x1`. Should have a numeric data type.
37
37
38
38
-**axis**: _int_
39
39
@@ -43,7 +43,7 @@ Returns the cross product of 3-element vectors. If `x1` and `x2` are multi-dimen
43
43
44
44
-**out**: _<array>_
45
45
46
-
- an array containing the cross products. The returned array must have a data type determined by {ref}`type-promotion` rules.
46
+
- an array containing the cross products. The returned array must have a data type determined by {ref}`type-promotion`.
47
47
48
48
(function-det)=
49
49
### det(x, /)
@@ -54,13 +54,13 @@ Returns the determinant of a square matrix (or stack of square matrices) `x`.
54
54
55
55
-**x**: _<array>_
56
56
57
-
- input array having shape `(..., M, M)` and whose innermost two dimensions form square matrices. Must have a data type of either `float32` or `float64`.
57
+
- input array having shape `(..., M, M)` and whose innermost two dimensions form square matrices. Should have a floating-point data type.
58
58
59
59
#### Returns
60
60
61
61
-**out**: _<array>_
62
62
63
-
- if `x` is a two-dimensional array, a zero-dimensional array containing the determinant; otherwise, a non-zero dimensional array containing the determinant for each square matrix. The returned array must have a data type determined by {ref}`type-promotion` rules.
63
+
- if `x` is a two-dimensional array, a zero-dimensional array containing the determinant; otherwise, a non-zero dimensional array containing the determinant for each square matrix. The returned array must have the same data type as `x`.
64
64
65
65
(function-diagonal)=
66
66
### diagonal(x, /, *, axis1=0, axis2=1, offset=0)
@@ -120,19 +120,19 @@ TODO
120
120
(function-inv)=
121
121
### inv(x, /)
122
122
123
-
Computes the multiplicative inverse of a square matrix (or stack of square matrices) `x`.
123
+
Computes the multiplicative inverse of a square matrix (or a stack of square matrices) `x`.
124
124
125
125
#### Parameters
126
126
127
127
-**x**: _<array>_
128
128
129
-
- input array having shape `(..., M, M)` and whose innermost two dimensions form square matrices. Must have a data type of either `float32` or `float64`.
129
+
- input array having shape `(..., M, M)` and whose innermost two dimensions form square matrices. Should have a floating-point data type.
130
130
131
131
#### Returns
132
132
133
133
-**out**: _<array>_
134
134
135
-
- an array containing the multiplicative inverses. The returned array must have the same data type and shape as `x`.
135
+
- an array containing the multiplicative inverses. The returned array must have a floating-point data type determined by {ref}`type-promotion`and must have the same shape as `x`.
136
136
137
137
(function-lstsq)=
138
138
### lstsq()
@@ -163,7 +163,7 @@ Computes the matrix or vector norm of `x`.
163
163
164
164
-**x**: _<array>_
165
165
166
-
- input array. Must have a data type of either `float32` or `float64`.
166
+
- input array. Should have a floating-point data type.
167
167
168
168
-**axis**: _Optional\[ Union\[ int, Tuple\[ int, int ]]]_
169
169
@@ -231,7 +231,7 @@ Computes the matrix or vector norm of `x`.
231
231
232
232
-**out**: _<array>_
233
233
234
-
- an array containing the norms. If `axis` is `None`, the output array must be a zero-dimensional array containing a vector norm. If `axis` is a scalar value (`int` or `float`), the output array must have a rank which is one less than the rank of `x`. If `axis` is a 2-tuple, the output array must have a rank which is two less than the rank of `x`. The returned array must have the same data type as `x`.
234
+
- an array containing the norms. If `axis` is `None`, the returned array must be a zero-dimensional array containing a vector norm. If `axis` is a scalar value (`int` or `float`), the returned array must have a rank which is one less than the rank of `x`. If `axis` is a 2-tuple, the returned array must have a rank which is two less than the rank of `x`. The returned array must have a floating-point data type determined by {ref}`type-promotion`.
235
235
236
236
(function-outer)=
237
237
### outer(x1, x2, /)
@@ -242,17 +242,17 @@ Computes the outer product of two vectors `x1` and `x2`.
242
242
243
243
-**x1**: _<array>_
244
244
245
-
- first one-dimensional input array of size `N`. Must have a data type of either `float32` or `float64`.
245
+
- first one-dimensional input array of size `N`. Should have a numeric data type.
246
246
247
247
-**x2**: _<array>_
248
248
249
-
- second one-dimensional input array of size `M`. Must have a data type of either `float32` or `float64`.
249
+
- second one-dimensional input array of size `M`. Should have a numeric data type.
250
250
251
251
#### Returns
252
252
253
253
-**out**: _<array>_
254
254
255
-
- a two-dimensional array containing the outer product and whose shape is `NxM`. The returned array must have a data type determined by {ref}`type-promotion` rules.
255
+
- a two-dimensional array containing the outer product and whose shape is `(N, M)`. The returned array must have a data type determined by {ref}`type-promotion`.
256
256
257
257
(function-pinv)=
258
258
### pinv()
@@ -288,7 +288,7 @@ Returns the sum along the specified diagonals. If `x` has more than two dimensio
288
288
289
289
-**x**: _<array>_
290
290
291
-
- input array. Must have at least `2` dimensions.
291
+
- input array. Must have at least `2` dimensions. Should have a numeric data type.
0 commit comments