Skip to content

Commit 31c94bb

Browse files
committed
Fix a test failure with torch.vector_norm
Also clean up the torch.linalg __all__
1 parent 93ce826 commit 31c94bb

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

array_api_compat/torch/linalg.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
import torch
66
array = torch.Tensor
77
from torch import dtype as Dtype
8-
from typing import Optional
8+
from typing import Optional, Union, Tuple, Literal
9+
inf = float('inf')
910

1011
from ._aliases import _fix_promotion, sum
1112

@@ -66,8 +67,22 @@ def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> arr
6667
# Use our wrapped sum to make sure it does upcasting correctly
6768
return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
6869

69-
__all__ = linalg_all + ['outer', 'trace', 'matmul', 'matrix_transpose', 'tensordot',
70-
'vecdot', 'solve']
70+
def vector_norm(
    x: array,
    /,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    keepdims: bool = False,
    ord: Union[int, float, Literal[inf, -inf]] = 2,
    **kwargs,
) -> array:
    """Compute the vector norm of ``x``, following array API conventions.

    Thin wrapper around ``torch.linalg.vector_norm`` that translates the
    array-API keyword names (``axis``, ``keepdims``) to torch's names
    (``dim``, ``keepdim``).

    Parameters
    ----------
    x : array
        Input tensor.
    axis : int or tuple of ints, optional
        Axis or axes along which to reduce; ``None`` reduces over all axes.
    keepdims : bool
        Whether reduced axes are kept with size one in the result.
    ord : int, float, or +/-inf
        Order of the norm (default 2).
    **kwargs
        Passed through to ``torch.linalg.vector_norm`` (e.g. ``dtype``,
        ``out``).
    """
    # torch.vector_norm incorrectly treats axis=() the same as axis=None
    if axis == ():
        keepdims = True
    # torch spells the reduction-axis keyword `dim`, not `axis`; passing
    # `axis=` is a TypeError, so rename it here (as keepdims->keepdim below).
    return torch.linalg.vector_norm(x, ord=ord, dim=axis, keepdim=keepdims, **kwargs)
83+
84+
# Public API of this module: everything torch.linalg already exports, plus
# the wrappers and aliases defined here.
__all__ = [
    *linalg_all,
    'outer', 'matmul', 'matrix_transpose', 'tensordot',
    'cross', 'vecdot', 'solve', 'trace', 'vector_norm',
]
7186

7287
_all_ignore = ['torch_linalg', 'sum']
7388

0 commit comments

Comments
 (0)