-
Notifications
You must be signed in to change notification settings - Fork 35
Torch linalg #33
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Torch linalg #33
Changes from all commits
c55eb43
8ca03d2
e2d202f
1a79fb0
148a892
eb66b69
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,27 @@ | ||
raise ImportError("The array api compat torch.linalg module extension is not yet implemented") | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING | ||
if TYPE_CHECKING: | ||
import torch | ||
array = torch.Tensor | ||
|
||
from torch.linalg import * | ||
|
||
# torch.linalg doesn't define __all__ | ||
# from torch.linalg import __all__ as linalg_all | ||
from torch import linalg as torch_linalg | ||
linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')] | ||
|
||
# These are implemented in torch but aren't in the linalg namespace | ||
from torch import outer, trace | ||
from ._aliases import _fix_promotion, matrix_transpose, tensordot | ||
|
||
# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the | ||
# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743 | ||
Comment on lines
+19
to
+20
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is just in The default for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh I didn't realize they were different. |
||
def cross(x1: array, x2: array, /, *, axis: int = -1) -> array: | ||
x1, x2 = _fix_promotion(x1, x2, only_scalar=False) | ||
return torch_linalg.cross(x1, x2, dim=axis) | ||
|
||
__all__ = linalg_all + ['outer', 'trace', 'matrix_transpose', 'tensordot'] | ||
|
||
del linalg_all |
Uh oh!
There was an error while loading. Please reload this page.