Skip to content

Commit 148a892

Browse files
committed
Define torch.linalg.cross wrapper
1 parent 1a79fb0 commit 148a892

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

array_api_compat/torch/linalg.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,27 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
if TYPE_CHECKING:
5+
import torch
6+
array = torch.Tensor
7+
18
from torch.linalg import *
29

310
# torch.linalg doesn't define __all__
411
# from torch.linalg import __all__ as linalg_all
5-
from torch import linalg as _linalg
6-
linalg_all = [i for i in dir(_linalg) if not i.startswith('_')]
12+
from torch import linalg as torch_linalg
13+
linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')]
714

815
# These are implemented in torch but aren't in the linalg namespace
916
from torch import outer, trace
10-
from ._aliases import matrix_transpose, tensordot
17+
from ._aliases import _fix_promotion, matrix_transpose, tensordot
18+
19+
# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
20+
# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
21+
def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
22+
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
23+
return torch_linalg.cross(x1, x2, dim=axis)
1124

1225
__all__ = linalg_all + ['outer', 'trace', 'matrix_transpose', 'tensordot']
1326

1427
del linalg_all
15-
del _linalg

0 commit comments

Comments
 (0)