|
12 | 12 | from numpy.linalg import __all__ as linalg_all
|
13 | 13 |
|
14 | 14 | # These are in the main NumPy namespace but not in numpy.linalg
|
15 |
| -from numpy import cross, diagonal, matmul, outer, tensordot, trace |
| 15 | +from numpy import cross, matmul, outer, tensordot |
16 | 16 |
|
17 | 17 | class EighResult(NamedTuple):
|
18 | 18 | eigenvalues: ndarray
|
@@ -141,6 +141,15 @@ def vector_norm(x: ndarray, /, *, axis: Optional[Union[int, Tuple[int, ...]]] =
|
141 | 141 |
|
142 | 142 | return res
|
143 | 143 |
|
| 144 | +# np.diagonal and np.trace operate on the first two axes whereas these |
| 145 | +# operates on the last two |
| 146 | + |
| 147 | +def diagonal(x: ndarray, /, *, offset: int = 0) -> ndarray: |
| 148 | + return np.diagonal(x, offset=offset, axis1=-2, axis2=-1) |
| 149 | + |
| 150 | +def trace(x: ndarray, /, *, offset: int = 0) -> ndarray: |
| 151 | + return np.asarray(np.trace(x, offset=offset, axis1=-2, axis2=-1)) |
| 152 | + |
144 | 153 | __all__ = linalg_all.copy()
|
145 | 154 | __all__ += ['cross', 'diagonal', 'matmul', 'cholesky', 'matrix_rank', 'pinv',
|
146 | 155 | 'matrix_norm', 'matrix_transpose', 'outer', 'svdvals',
|
|
0 commit comments