|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from typing import TYPE_CHECKING |
| 4 | +if TYPE_CHECKING: |
| 5 | + import torch |
| 6 | + array = torch.Tensor |
| 7 | + |
1 | 8 | from torch.linalg import *
|
2 | 9 |
|
3 | 10 | # torch.linalg doesn't define __all__
|
4 | 11 | # from torch.linalg import __all__ as linalg_all
|
5 |
| -from torch import linalg as _linalg |
6 |
| -linalg_all = [i for i in dir(_linalg) if not i.startswith('_')] |
| 12 | +from torch import linalg as torch_linalg |
| 13 | +linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')] |
7 | 14 |
|
8 | 15 | # These are implemented in torch but aren't in the linalg namespace
|
9 | 16 | from torch import outer, trace
|
10 |
| -from ._aliases import matrix_transpose, tensordot |
| 17 | +from ._aliases import _fix_promotion, matrix_transpose, tensordot |
| 18 | + |
| 19 | +# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the |
| 20 | +# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743 |
| 21 | +def cross(x1: array, x2: array, /, *, axis: int = -1) -> array: |
| 22 | + x1, x2 = _fix_promotion(x1, x2, only_scalar=False) |
| 23 | + return torch_linalg.cross(x1, x2, dim=axis) |
11 | 24 |
|
12 | 25 | __all__ = linalg_all + ['outer', 'trace', 'matrix_transpose', 'tensordot']
|
13 | 26 |
|
14 | 27 | del linalg_all
|
15 |
| -del _linalg |
|
0 commit comments