Skip to content

Commit b32a5b3

Browse files
authored
Merge pull request #33 from asmeurer/torch-linalg
Torch linalg
2 parents e6007fa + eb66b69 commit b32a5b3

File tree

5 files changed

+46
-24
lines changed

5 files changed

+46
-24
lines changed

array_api_compat/cupy/__init__.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,7 @@
66
# These imports may overwrite names from the import * above.
77
from ._aliases import *
88

9-
# Don't know why, but we have to do an absolute import to import linalg. If we
10-
# instead do
11-
#
12-
# from . import linalg
13-
#
14-
# It doesn't overwrite cupy.linalg from above. The import is generated
15-
# dynamically so that the library can be vendored.
9+
# See the comment in the numpy __init__.py
1610
__import__(__package__ + '.linalg')
1711

1812
from .linalg import matrix_transpose, vecdot

array_api_compat/torch/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
# These imports may overwrite names from the import * above.
1515
from ._aliases import *
1616

17+
# See the comment in the numpy __init__.py
18+
__import__(__package__ + '.linalg')
19+
1720
from ..common._helpers import *
1821

1922
__array_api_version__ = '2021.12'

array_api_compat/torch/_aliases.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,15 @@
99
vecdot as _aliases_vecdot)
1010
from .._internal import get_xp
1111

12+
import torch
13+
1214
from typing import TYPE_CHECKING
1315
if TYPE_CHECKING:
1416
from typing import List, Optional, Sequence, Tuple, Union
1517
from ..common._typing import Device
1618
from torch import dtype as Dtype
1719

18-
import torch
19-
array = torch.Tensor
20+
array = torch.Tensor
2021

2122
_int_dtypes = {
2223
torch.uint8,
@@ -547,6 +548,14 @@ def empty(shape: Union[int, Tuple[int, ...]],
547548
**kwargs) -> array:
548549
return torch.empty(shape, dtype=dtype, device=device, **kwargs)
549550

551+
# tril and triu do not call the keyword argument k
552+
553+
def tril(x: array, /, *, k: int = 0) -> array:
554+
return torch.tril(x, k)
555+
556+
def triu(x: array, /, *, k: int = 0) -> array:
557+
return torch.triu(x, k)
558+
550559
# Functions that aren't in torch https://github.com/pytorch/pytorch/issues/58742
551560
def expand_dims(x: array, /, *, axis: int = 0) -> array:
552561
return torch.unsqueeze(x, axis)
@@ -651,6 +660,7 @@ def isdtype(
651660
'subtract', 'max', 'min', 'sort', 'prod', 'sum', 'any', 'all',
652661
'mean', 'std', 'var', 'concat', 'squeeze', 'flip', 'roll',
653662
'nonzero', 'where', 'arange', 'eye', 'linspace', 'full', 'ones',
654-
'zeros', 'empty', 'expand_dims', 'astype', 'broadcast_arrays',
655-
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
656-
'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype']
663+
'zeros', 'empty', 'tril', 'triu', 'expand_dims', 'astype',
664+
'broadcast_arrays', 'unique_all', 'unique_counts',
665+
'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose',
666+
'vecdot', 'tensordot', 'isdtype']

array_api_compat/torch/linalg.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,27 @@
1-
raise ImportError("The array api compat torch.linalg module extension is not yet implemented")
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
if TYPE_CHECKING:
5+
import torch
6+
array = torch.Tensor
7+
8+
from torch.linalg import *
9+
10+
# torch.linalg doesn't define __all__
11+
# from torch.linalg import __all__ as linalg_all
12+
from torch import linalg as torch_linalg
13+
linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')]
14+
15+
# These are implemented in torch but aren't in the linalg namespace
16+
from torch import outer, trace
17+
from ._aliases import _fix_promotion, matrix_transpose, tensordot
18+
19+
# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
20+
# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
21+
def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
22+
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
23+
return torch_linalg.cross(x1, x2, dim=axis)
24+
25+
__all__ = linalg_all + ['outer', 'trace', 'matrix_transpose', 'tensordot']
26+
27+
del linalg_all

torch-xfails.txt

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,24 +26,13 @@ array_api_tests/test_data_type_functions.py::test_iinfo[uint16]
2626
array_api_tests/test_data_type_functions.py::test_iinfo[uint32]
2727
array_api_tests/test_data_type_functions.py::test_iinfo[uint64]
2828

29-
# --disable-extension broken with test_has_names.py
30-
# https://github.com/data-apis/array-api-tests/issues/169
31-
array_api_tests/test_has_names.py::test_has_names[linalg-matrix_transpose]
32-
array_api_tests/test_has_names.py::test_has_names[linalg-outer]
33-
array_api_tests/test_has_names.py::test_has_names[linalg-tensordot]
34-
array_api_tests/test_has_names.py::test_has_names[linalg-trace]
35-
3629
# We cannot wrap the tensor object
3730
array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__]
3831
array_api_tests/test_has_names.py::test_has_names[array_method-to_device]
3932

40-
4133
# tensordot doesn't allow integer dtypes in some corner cases
4234
array_api_tests/test_linalg.py::test_tensordot
4335

44-
# A numerical difference in stacking (will be fixed by
45-
# https://github.com/data-apis/array-api-tests/pull/101)
46-
array_api_tests/test_linalg.py::test_matmul
4736
# We cannot wrap the tensor object
4837
array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x1, x2)]
4938
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__and__(x1, x2)]

0 commit comments

Comments
 (0)