Skip to content

Commit 88f46fa

Browse files
committed
Add the dtype argument to numpy.array_api.linalg.trace
Original NumPy Commit: 4e2a03ab936ac5035640df75e71965074c7d84c6
1 parent 83c30d9 commit 88f46fa

File tree

1 file changed

+19
-4
lines changed

1 file changed

+19
-4
lines changed

array_api_strict/linalg.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
11
from __future__ import annotations
22

3-
from ._dtypes import _floating_dtypes, _numeric_dtypes
3+
from ._dtypes import (
4+
_floating_dtypes,
5+
_numeric_dtypes,
6+
float32,
7+
float64,
8+
complex64,
9+
complex128
10+
)
411
from ._manipulation_functions import reshape
512
from ._array_object import Array
613

714
from numpy.core.numeric import normalize_axis_tuple
815

916
from typing import TYPE_CHECKING
1017
if TYPE_CHECKING:
11-
from ._typing import Literal, Optional, Sequence, Tuple, Union
18+
from ._typing import Literal, Optional, Sequence, Tuple, Union, Dtype
1219

1320
from typing import NamedTuple
1421

@@ -363,17 +370,25 @@ def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int],
363370
return Array._new(np.tensordot(x1._array, x2._array, axes=axes))
364371

365372
# Note: trace is the numpy top-level namespace, not np.linalg
366-
def trace(x: Array, /, *, offset: int = 0) -> Array:
373+
def trace(x: Array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> Array:
367374
"""
368375
Array API compatible wrapper for :py:func:`np.trace <numpy.trace>`.
369376
370377
See its docstring for more information.
371378
"""
372379
if x.dtype not in _numeric_dtypes:
373380
raise TypeError('Only numeric dtypes are allowed in trace')
381+
382+
# Note: trace() works the same as sum() and prod() (see
383+
# _statistical_functions.py)
384+
if dtype is None:
385+
if x.dtype == float32:
386+
dtype = float64
387+
elif x.dtype == complex64:
388+
dtype = complex128
374389
# Note: trace always operates on the last two axes, whereas np.trace
375390
# operates on the first two axes by default
376-
return Array._new(np.asarray(np.trace(x._array, offset=offset, axis1=-2, axis2=-1)))
391+
return Array._new(np.asarray(np.trace(x._array, offset=offset, axis1=-2, axis2=-1, dtype=dtype)))
377392

378393
# Note: vecdot is not in NumPy
379394
def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:

0 commit comments

Comments
 (0)