Skip to content

Commit c0911bb

Browse files
authored
Update dtype requirements for existing linear algebra APIs (#148)
* Add dtype requirements * Loosen dtype requirements * Update copy * Restrict `det` to floating-point data types For more than 3-dimensions, efficiently computing the determinant relies on factorization, which is likely to be inexact and require floating-point computation. Accordingly, we should restrict input dtypes to avoid potential casting issues (e.g., int64 to float64).
1 parent 0941067 commit c0911bb

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

spec/API_specification/linear_algebra_functions.md

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,11 @@ Returns the cross product of 3-element vectors. If `x1` and `x2` are multi-dimen
2929

3030
- **x1**: _<array>_
3131

32-
- first input array. Must have a data type of either `float32` or `float64`.
32+
- first input array. Should have a numeric data type.
3333

3434
- **x2**: _<array>_
3535

36-
- second input array. Must have the same shape as `x1`. Must have a data type of either `float32` or `float64`.
36+
- second input array. Must have the same shape as `x1`. Should have a numeric data type.
3737

3838
- **axis**: _int_
3939

@@ -43,7 +43,7 @@ Returns the cross product of 3-element vectors. If `x1` and `x2` are multi-dimen
4343

4444
- **out**: _<array>_
4545

46-
- an array containing the cross products. The returned array must have a data type determined by {ref}`type-promotion` rules.
46+
- an array containing the cross products. The returned array must have a data type determined by {ref}`type-promotion`.
4747

4848
(function-det)=
4949
### det(x, /)
@@ -54,13 +54,13 @@ Returns the determinant of a square matrix (or stack of square matrices) `x`.
5454

5555
- **x**: _<array>_
5656

57-
- input array having shape `(..., M, M)` and whose innermost two dimensions form square matrices. Must have a data type of either `float32` or `float64`.
57+
- input array having shape `(..., M, M)` and whose innermost two dimensions form square matrices. Should have a floating-point data type.
5858

5959
#### Returns
6060

6161
- **out**: _<array>_
6262

63-
- if `x` is a two-dimensional array, a zero-dimensional array containing the determinant; otherwise, a non-zero dimensional array containing the determinant for each square matrix. The returned array must have a data type determined by {ref}`type-promotion` rules.
63+
- if `x` is a two-dimensional array, a zero-dimensional array containing the determinant; otherwise, a non-zero dimensional array containing the determinant for each square matrix. The returned array must have the same data type as `x`.
6464

6565
(function-diagonal)=
6666
### diagonal(x, /, *, axis1=0, axis2=1, offset=0)
@@ -120,19 +120,19 @@ TODO
120120
(function-inv)=
121121
### inv(x, /)
122122

123-
Computes the multiplicative inverse of a square matrix (or stack of square matrices) `x`.
123+
Computes the multiplicative inverse of a square matrix (or a stack of square matrices) `x`.
124124

125125
#### Parameters
126126

127127
- **x**: _<array>_
128128

129-
- input array having shape `(..., M, M)` and whose innermost two dimensions form square matrices. Must have a data type of either `float32` or `float64`.
129+
- input array having shape `(..., M, M)` and whose innermost two dimensions form square matrices. Should have a floating-point data type.
130130

131131
#### Returns
132132

133133
- **out**: _<array>_
134134

135-
- an array containing the multiplicative inverses. The returned array must have the same data type and shape as `x`.
135+
- an array containing the multiplicative inverses. The returned array must have a floating-point data type determined by {ref}`type-promotion` and must have the same shape as `x`.
136136

137137
(function-lstsq)=
138138
### lstsq()
@@ -163,7 +163,7 @@ Computes the matrix or vector norm of `x`.
163163

164164
- **x**: _<array>_
165165

166-
- input array. Must have a data type of either `float32` or `float64`.
166+
- input array. Should have a floating-point data type.
167167

168168
- **axis**: _Optional\[ Union\[ int, Tuple\[ int, int ] ] ]_
169169

@@ -231,7 +231,7 @@ Computes the matrix or vector norm of `x`.
231231

232232
- **out**: _<array>_
233233

234-
- an array containing the norms. If `axis` is `None`, the output array must be a zero-dimensional array containing a vector norm. If `axis` is a scalar value (`int` or `float`), the output array must have a rank which is one less than the rank of `x`. If `axis` is a 2-tuple, the output array must have a rank which is two less than the rank of `x`. The returned array must have the same data type as `x`.
234+
- an array containing the norms. If `axis` is `None`, the returned array must be a zero-dimensional array containing a vector norm. If `axis` is a scalar value (`int` or `float`), the returned array must have a rank which is one less than the rank of `x`. If `axis` is a 2-tuple, the returned array must have a rank which is two less than the rank of `x`. The returned array must have a floating-point data type determined by {ref}`type-promotion`.
235235

236236
(function-outer)=
237237
### outer(x1, x2, /)
@@ -242,17 +242,17 @@ Computes the outer product of two vectors `x1` and `x2`.
242242

243243
- **x1**: _<array>_
244244

245-
- first one-dimensional input array of size `N`. Must have a data type of either `float32` or `float64`.
245+
- first one-dimensional input array of size `N`. Should have a numeric data type.
246246

247247
- **x2**: _<array>_
248248

249-
- second one-dimensional input array of size `M`. Must have a data type of either `float32` or `float64`.
249+
- second one-dimensional input array of size `M`. Should have a numeric data type.
250250

251251
#### Returns
252252

253253
- **out**: _<array>_
254254

255-
- a two-dimensional array containing the outer product and whose shape is `NxM`. The returned array must have a data type determined by {ref}`type-promotion` rules.
255+
- a two-dimensional array containing the outer product and whose shape is `(N, M)`. The returned array must have a data type determined by {ref}`type-promotion`.
256256

257257
(function-pinv)=
258258
### pinv()
@@ -288,7 +288,7 @@ Returns the sum along the specified diagonals. If `x` has more than two dimensio
288288

289289
- **x**: _<array>_
290290

291-
- input array. Must have at least `2` dimensions.
291+
- input array. Must have at least `2` dimensions. Should have a numeric data type.
292292

293293
- **axis1**: _int_
294294

0 commit comments

Comments
 (0)