|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
3 |
| -from ._dtypes import _floating_dtypes, _numeric_dtypes |
| 3 | +from ._dtypes import ( |
| 4 | + _floating_dtypes, |
| 5 | + _numeric_dtypes, |
| 6 | + float32, |
| 7 | + float64, |
| 8 | + complex64, |
| 9 | + complex128 |
| 10 | +) |
4 | 11 | from ._manipulation_functions import reshape
|
5 | 12 | from ._array_object import Array
|
6 | 13 |
|
7 | 14 | from numpy.core.numeric import normalize_axis_tuple
|
8 | 15 |
|
9 | 16 | from typing import TYPE_CHECKING
|
10 | 17 | if TYPE_CHECKING:
|
11 |
| - from ._typing import Literal, Optional, Sequence, Tuple, Union |
| 18 | + from ._typing import Literal, Optional, Sequence, Tuple, Union, Dtype |
12 | 19 |
|
13 | 20 | from typing import NamedTuple
|
14 | 21 |
|
@@ -363,17 +370,25 @@ def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int],
|
363 | 370 | return Array._new(np.tensordot(x1._array, x2._array, axes=axes))
|
364 | 371 |
|
365 | 372 | # Note: trace is the numpy top-level namespace, not np.linalg
|
366 |
| -def trace(x: Array, /, *, offset: int = 0) -> Array: |
| 373 | +def trace(x: Array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> Array: |
367 | 374 | """
|
368 | 375 | Array API compatible wrapper for :py:func:`np.trace <numpy.trace>`.
|
369 | 376 |
|
370 | 377 | See its docstring for more information.
|
371 | 378 | """
|
372 | 379 | if x.dtype not in _numeric_dtypes:
|
373 | 380 | raise TypeError('Only numeric dtypes are allowed in trace')
|
| 381 | + |
| 382 | + # Note: trace() works the same as sum() and prod() (see |
| 383 | + # _statistical_functions.py) |
| 384 | + if dtype is None: |
| 385 | + if x.dtype == float32: |
| 386 | + dtype = float64 |
| 387 | + elif x.dtype == complex64: |
| 388 | + dtype = complex128 |
374 | 389 | # Note: trace always operates on the last two axes, whereas np.trace
|
375 | 390 | # operates on the first two axes by default
|
376 |
| - return Array._new(np.asarray(np.trace(x._array, offset=offset, axis1=-2, axis2=-1))) |
| 391 | + return Array._new(np.asarray(np.trace(x._array, offset=offset, axis1=-2, axis2=-1, dtype=dtype))) |
377 | 392 |
|
378 | 393 | # Note: vecdot is not in NumPy
|
379 | 394 | def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
|
|
0 commit comments