Skip to content

Torch linalg #33

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions array_api_compat/cupy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,7 @@
# These imports may overwrite names from the import * above.
from ._aliases import *

# Don't know why, but we have to do an absolute import to import linalg. If we
# instead do
#
# from . import linalg
#
# It doesn't overwrite cupy.linalg from above. The import is generated
# dynamically so that the library can be vendored.
# See the comment in the numpy __init__.py
__import__(__package__ + '.linalg')

from .linalg import matrix_transpose, vecdot
Expand Down
3 changes: 3 additions & 0 deletions array_api_compat/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
# These imports may overwrite names from the import * above.
from ._aliases import *

# See the comment in the numpy __init__.py
__import__(__package__ + '.linalg')

from ..common._helpers import *

__array_api_version__ = '2021.12'
20 changes: 15 additions & 5 deletions array_api_compat/torch/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@
vecdot as _aliases_vecdot)
from .._internal import get_xp

import torch

from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import List, Optional, Sequence, Tuple, Union
from ..common._typing import Device
from torch import dtype as Dtype

import torch
array = torch.Tensor
array = torch.Tensor

_int_dtypes = {
torch.uint8,
Expand Down Expand Up @@ -547,6 +548,14 @@ def empty(shape: Union[int, Tuple[int, ...]],
**kwargs) -> array:
return torch.empty(shape, dtype=dtype, device=device, **kwargs)

# tril and triu do not call the keyword argument k

def tril(x: array, /, *, k: int = 0) -> array:
return torch.tril(x, k)

def triu(x: array, /, *, k: int = 0) -> array:
return torch.triu(x, k)

# Functions that aren't in torch https://github.com/pytorch/pytorch/issues/58742
def expand_dims(x: array, /, *, axis: int = 0) -> array:
return torch.unsqueeze(x, axis)
Expand Down Expand Up @@ -651,6 +660,7 @@ def isdtype(
'subtract', 'max', 'min', 'sort', 'prod', 'sum', 'any', 'all',
'mean', 'std', 'var', 'concat', 'squeeze', 'flip', 'roll',
'nonzero', 'where', 'arange', 'eye', 'linspace', 'full', 'ones',
'zeros', 'empty', 'expand_dims', 'astype', 'broadcast_arrays',
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype']
'zeros', 'empty', 'tril', 'triu', 'expand_dims', 'astype',
'broadcast_arrays', 'unique_all', 'unique_counts',
'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose',
'vecdot', 'tensordot', 'isdtype']
28 changes: 27 additions & 1 deletion array_api_compat/torch/linalg.py
Original file line number Diff line number Diff line change
@@ -1 +1,27 @@
raise ImportError("The array api compat torch.linalg module extension is not yet implemented")
from __future__ import annotations

from typing import TYPE_CHECKING
if TYPE_CHECKING:
import torch
array = torch.Tensor

from torch.linalg import *

# torch.linalg doesn't define __all__
# from torch.linalg import __all__ as linalg_all
from torch import linalg as torch_linalg
linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')]

# These are implemented in torch but aren't in the linalg namespace
from torch import outer, trace
from ._aliases import _fix_promotion, matrix_transpose, tensordot

# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
Comment on lines +19 to +20
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just in torch.cross though. torch.linalg.cross has the saner dim=-1 default.

The default for torch.cross is easily in my top 5 of craziest behaviours of PyTorch core though.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I didn't realize they were different.

def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
return torch_linalg.cross(x1, x2, dim=axis)

__all__ = linalg_all + ['outer', 'trace', 'matrix_transpose', 'tensordot']

del linalg_all
11 changes: 0 additions & 11 deletions torch-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,13 @@ array_api_tests/test_data_type_functions.py::test_iinfo[uint16]
array_api_tests/test_data_type_functions.py::test_iinfo[uint32]
array_api_tests/test_data_type_functions.py::test_iinfo[uint64]

# --disable-extension broken with test_has_names.py
# https://github.com/data-apis/array-api-tests/issues/169
array_api_tests/test_has_names.py::test_has_names[linalg-matrix_transpose]
array_api_tests/test_has_names.py::test_has_names[linalg-outer]
array_api_tests/test_has_names.py::test_has_names[linalg-tensordot]
array_api_tests/test_has_names.py::test_has_names[linalg-trace]

# We cannot wrap the tensor object
array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__]
array_api_tests/test_has_names.py::test_has_names[array_method-to_device]


# tensordot doesn't allow integer dtypes in some corner cases
array_api_tests/test_linalg.py::test_tensordot

# A numerical difference in stacking (will be fixed by
# https://github.com/data-apis/array-api-tests/pull/101)
array_api_tests/test_linalg.py::test_matmul
# We cannot wrap the tensor object
array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__and__(x1, x2)]
Expand Down