diff --git a/array_api_compat/cupy/__init__.py b/array_api_compat/cupy/__init__.py
index e4f43c13..10c31bc6 100644
--- a/array_api_compat/cupy/__init__.py
+++ b/array_api_compat/cupy/__init__.py
@@ -6,13 +6,7 @@
 # These imports may overwrite names from the import * above.
 from ._aliases import *
 
-# Don't know why, but we have to do an absolute import to import linalg. If we
-# instead do
-#
-# from . import linalg
-#
-# It doesn't overwrite cupy.linalg from above. The import is generated
-# dynamically so that the library can be vendored.
+# See the comment in the numpy __init__.py
 __import__(__package__ + '.linalg')
 
 from .linalg import matrix_transpose, vecdot
diff --git a/array_api_compat/torch/__init__.py b/array_api_compat/torch/__init__.py
index 7dfdf482..18776f1a 100644
--- a/array_api_compat/torch/__init__.py
+++ b/array_api_compat/torch/__init__.py
@@ -14,6 +14,9 @@
 # These imports may overwrite names from the import * above.
 from ._aliases import *
 
+# See the comment in the numpy __init__.py
+__import__(__package__ + '.linalg')
+
 from ..common._helpers import *
 
 __array_api_version__ = '2021.12'
diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py
index c23ba059..dbd4d8d9 100644
--- a/array_api_compat/torch/_aliases.py
+++ b/array_api_compat/torch/_aliases.py
@@ -9,14 +9,15 @@
                                vecdot as _aliases_vecdot)
 from .._internal import get_xp
 
+import torch
+
 from typing import TYPE_CHECKING
 if TYPE_CHECKING:
     from typing import List, Optional, Sequence, Tuple, Union
     from ..common._typing import Device
     from torch import dtype as Dtype
 
-import torch
-array = torch.Tensor
+array = torch.Tensor
 
 _int_dtypes = {
     torch.uint8,
@@ -547,6 +548,14 @@
 def empty(shape: Union[int, Tuple[int, ...]], **kwargs) -> array:
     return torch.empty(shape, dtype=dtype, device=device, **kwargs)
 
+# torch.tril and torch.triu do not take the keyword argument k (torch calls
+# it diagonal)
+
+def tril(x: array, /, *, k: int = 0) -> array:
+    return torch.tril(x, k)
+
+def triu(x: array, /, *, k: int = 0) -> array:
+    return torch.triu(x, k)
+
 # Functions that aren't in torch https://github.com/pytorch/pytorch/issues/58742
 def expand_dims(x: array, /, *, axis: int = 0) -> array:
     return torch.unsqueeze(x, axis)
@@ -651,6 +660,7 @@ def isdtype(
            'subtract', 'max', 'min', 'sort', 'prod', 'sum', 'any', 'all',
            'mean', 'std', 'var', 'concat', 'squeeze', 'flip', 'roll',
            'nonzero', 'where', 'arange', 'eye', 'linspace', 'full', 'ones',
-           'zeros', 'empty', 'expand_dims', 'astype', 'broadcast_arrays',
-           'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
-           'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype']
+           'zeros', 'empty', 'tril', 'triu', 'expand_dims', 'astype',
+           'broadcast_arrays', 'unique_all', 'unique_counts',
+           'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose',
+           'vecdot', 'tensordot', 'isdtype']
diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py
index 8d223fd4..c803228a 100644
--- a/array_api_compat/torch/linalg.py
+++ b/array_api_compat/torch/linalg.py
@@ -1 +1,27 @@
-raise ImportError("The array api compat torch.linalg module extension is not yet implemented")
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+    import torch
+    array = torch.Tensor
+
+from torch.linalg import *
+
+# torch.linalg doesn't define __all__
+# from torch.linalg import __all__ as linalg_all
+from torch import linalg as torch_linalg
+linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')]
+
+# These are implemented in torch but aren't in the linalg namespace
+from torch import outer, trace
+from ._aliases import _fix_promotion, matrix_transpose, tensordot
+
+# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
+# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
+def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
+    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
+    return torch_linalg.cross(x1, x2, dim=axis)
+
+__all__ = linalg_all + ['outer', 'trace', 'matrix_transpose', 'tensordot']
+
+del linalg_all
diff --git a/torch-xfails.txt b/torch-xfails.txt
index f67ca189..d6cf1670 100644
--- a/torch-xfails.txt
+++ b/torch-xfails.txt
@@ -26,24 +26,13 @@
 array_api_tests/test_data_type_functions.py::test_iinfo[uint16]
 array_api_tests/test_data_type_functions.py::test_iinfo[uint32]
 array_api_tests/test_data_type_functions.py::test_iinfo[uint64]
 
-# --disable-extension broken with test_has_names.py
-# https://github.com/data-apis/array-api-tests/issues/169
-array_api_tests/test_has_names.py::test_has_names[linalg-matrix_transpose]
-array_api_tests/test_has_names.py::test_has_names[linalg-outer]
-array_api_tests/test_has_names.py::test_has_names[linalg-tensordot]
-array_api_tests/test_has_names.py::test_has_names[linalg-trace]
-
 # We cannot wrap the tensor object
 array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__]
 array_api_tests/test_has_names.py::test_has_names[array_method-to_device]
-
 # tensordot doesn't allow integer dtypes in some corner cases
 array_api_tests/test_linalg.py::test_tensordot
-# A numerical difference in stacking (will be fixed by
-# https://github.com/data-apis/array-api-tests/pull/101)
-array_api_tests/test_linalg.py::test_matmul
 
 # We cannot wrap the tensor object
 array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x1, x2)]
 array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__and__(x1, x2)]
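
As a quick reviewer sanity check (a sketch, not part of the patch), the snippet below exercises the new wrappers; it assumes torch is installed and this repo is importable as array_api_compat:

    import torch
    from array_api_compat.torch import triu, linalg

    x = torch.ones((3, 3))
    # k is keyword-only per the spec; the wrapper forwards it to torch.triu
    print(triu(x, k=1))

    a = torch.tensor([1.0, 0.0, 0.0])
    b = torch.tensor([0.0, 1.0, 0.0])
    # the wrapped cross pins the spec's axis=-1 default (see the note in
    # torch/linalg.py above) and applies the usual type-promotion fix
    print(linalg.cross(a, b))  # tensor([0., 0., 1.])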