Skip to content

Commit e8f7c24

Browse files
committed
Add aliases for diagonal() and trace()
1 parent 4d1829f commit e8f7c24

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

numpy_array_api_compat/linalg.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from numpy.linalg import __all__ as linalg_all
1313

1414
# These are in the main NumPy namespace but not in numpy.linalg
15-
from numpy import cross, diagonal, matmul, outer, tensordot, trace
15+
from numpy import cross, matmul, outer, tensordot
1616

1717
class EighResult(NamedTuple):
1818
eigenvalues: ndarray
@@ -141,6 +141,15 @@ def vector_norm(x: ndarray, /, *, axis: Optional[Union[int, Tuple[int, ...]]] =
141141

142142
return res
143143

144+
# np.diagonal and np.trace operate on the first two axes whereas these
145+
# operates on the last two
146+
147+
def diagonal(x: ndarray, /, *, offset: int = 0) -> ndarray:
148+
return np.diagonal(x, offset=offset, axis1=-2, axis2=-1)
149+
150+
def trace(x: ndarray, /, *, offset: int = 0) -> ndarray:
151+
return np.asarray(np.trace(x, offset=offset, axis1=-2, axis2=-1))
152+
144153
__all__ = linalg_all.copy()
145154
__all__ += ['cross', 'diagonal', 'matmul', 'cholesky', 'matrix_rank', 'pinv',
146155
'matrix_norm', 'matrix_transpose', 'outer', 'svdvals',

0 commit comments

Comments
 (0)